Pytorchで新しくRNNを作るには?
研究でRNNを使う時に,デフォルトのtorch.nn.RNN()等を使うのではなくて,
自作したRNNを使用したいと思っていたのですが,やり方に結構悩んだので,備忘録的にまとめておきます.
今回扱うのはデフォルトのRNNCellやLSTM,GRUなどを使ってモデルを組む段階の話ではなく,
RNNCellそのものを自作したり,内部のリンクを変更したいと思った時にどうすればいいかというお話です.
結論から言ってしまえば,nn.Moduleを継承してRNNのクラスを作るだけでした.
※ゼロから作るDeep Learning 3 のp.475~477がとても参考になりました.
部分的なコードは以下のような感じです.
import torch
import torch.nn as nn
class RNNCell(nn.Module):
def __init__(self,n_in,n_hid):
super(RNNCell,self).__init__()
self.i2h=nn.Linear(n_in,n_hid)
self.h2h=nn.Linear(n_hid,n_hid)
self.h=None
def reset_state(self):
self.h=None
def forward(self,x):
if self.h is None:
h_new=self.i2h(x)
else:
h_new=self.i2h(x)+self.h2h(torch.tanh(self.h))
self.h=h_new
h_out=torch.tanh(h_new)
h_out=h_out.detach()
return h_out
ちなみに上のコードのh_newは隠れ状態ではありません.
その一歩手前の(tanhに通す前の)内部状態を表しています.
また,outputのリンク部分は実際に回すモデルの方(RNNCellを使ってモデルを作る時に)で指定してます.
detach()の部分は自分もよく分からないです.
BackProp周りのエラーを解消するために調べて出てきたものを付けました.
後は好きなように
h_new=self.i2h(x)+self.h2h(torch.tanh(self.h))
の部分などを変更すれば,任意のRNNが作れるはずです.
Author And Source
この問題について(Pytorchで新しくRNNを作るには?), 我々は、より多くの情報をここで見つけました https://qiita.com/tsubasa_hizono/items/d3095a4dc7e4cf91ffdb著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .