LSTMとSENet


なんかLSTMに関して理解してみたらCNNにおけるSENetの構造と似てるんじゃないかなと思ったので記事に起こしてみます。SENetに関しては割とどういう解釈が正解なのかよく分からないので多分に間違っている可能性が高いですが、真に受けずにそういう解釈もあるのだな程度に留めてください。

そもそもLSTMとは?

LSTM(Long Short-Term Memory)といってRNNの一種で通常のRNNが数10個ステップくらいの情報を引き継げないのに対して1000ステップの長距離の情報を保存することができます。
良くある図ではLSTMは例えば下記のように示されます。
スマートですね。スマートすぎて自分はこの図では意味が理解できませんでしたが。

RNNの歴史推定

そもそもRNNの生い立ちを知らないのでそれを想像で推定しながら書いていきます。

RNNの前

通常の順伝播型NNを時系列データに適用することを考えてみましょう。
すると、以下のような多層パーセプトロンモデルをsigmoid関数で出力に変換するモデルを最初考えたでしょう。

ところがこのモデルでいざ時系列データを逐一予測しようとするとおそらく上手く行かなかったに違いありません。何故なら時系列データにこのモデルを繰り返し適用すると以下のようになります。

モデルが繰り返し使われることから複数個前のステップのデータを用いるモデルは実質的に非常に深いモデルと等価と見なせるのです。ここでモデルの出力のために活性化関数にsigmoid関数が使われることに注目しましょう。中間層の活性化関数にsigmoid関数をつかうと勾配消失が起こりやすく多層での学習が上手く行われないと言われています。(このため現在の深層学習において中間層の活性化関数にはsigmoid関数ではなく代わりにrelu関数が使われます)
とするなら上記の時系列モデルにおいて学習可能なのはせいぜい1~2ステップまでで3ステップ前のパターンは勾配消失で学習できないことになります。

追記

記事を書き終えた後気付きましたが、勾配消失の原因がsigmoid関数としましたが、もしかしたら出力に変換するのに中間層を絞ることかもしれません。例えがあれだがAutoEncoderでEncorder出力の潜在変数の次元数が低すぎるとAutoEncoderの復元精度が下がるのと同じことです。とはいえ参考書的にはRNNで勾配消失が克服できるとあるので、潜在変数の次元数低下は関係ないかもしれません。

RNNの発想

RNN以前のモデルの何が問題だといえば出力を生成するためのsigmoid関数です。
わざわざ出力に変換してから再度入力として取り込む作業で勾配が消えてしまいます。
じゃあ出力に変換する前の中間層を直接次に渡せばいいんじゃないか?これがRNNの原理です。
順伝播型NNと違いRNNでは内部状態(記憶)を保持しているという解釈の方が人によっては良いかもしれません。


緑線がRNNにおいて追加された経路です。
これにより勾配伝達可能なステップ数はsigmoid関数をskip出来て学習可能なDNNの層数になりました(二十層程度?)。一方で、順伝播型DNNも数十層を超えると学習可能ではないようにこのRNNの学習にも上限がありました。

LSTMの発想

結局のところLSTMの発想はといえば渡す中間層をモデル内のNN層で汚染しないことだと思います。上記RNNでは渡す中間層が必ずNNモデルの入力の一部となっていましたが、これを独立させたの橙線を追加させたのがLSTMかと思います。橙線はRNNで渡す中間層よりゆっくり変わっていくと思われるので長距離の情報を保存できます。

実際のLSTM:

参考:http://colah.github.io/posts/2015-08-Understanding-LSTMs/
(正確にはLSTMの実装は何通りかあるみたいです)

Residualモジュール

ここで橙線と緑線の関係を追ってみるといわゆるResNetにおけるResidualモジュールに近い形をしているのが分かります。ResNetが畳み込み百層以上でも学習できることとLSTMが長距離の情報を保存するのは実は同じなんじゃないかと考えられます。

SE_Block部分

さて、先ほど少しスルーしてましたが実際のLSTMをもう一度よく見てみます。すると上図の紫線におけるsigmoid関数を掛けたものの積である短距離記憶を長距離記憶に足すという処理が行われているのが分かります。LSTMのこの処理を見た時、自分はSENetにおけるSE_Blockに近いのでは?と思いました。歴史的にはLSTMの方がずっと早いのですが。

def se_block(input, channels, r=8):
    # Squeeze
    x = GlobalAveragePooling2D()(input)
    # Excitation
    x = Dense(channels//r, activation="relu")(x)
    x = Dense(channels, activation="sigmoid")(x)
    return Multiply()([input, x])

SE_Block部分2

もう一つLSTMのSE_Blockに近い部分にてNNにtanhを掛けるのはなぜでしょうか。
これは強制的に値を-1~1に変換することになります。これで思い出すのは正規化(Batch_Normalization)です。LSTMが開発された当時(1997年)には相当する正規化がなかったため、これが使われたのかもしれません。
逆説的に言うならSE_Blockはチャンネル毎に係数を掛けることになりますがこれはBatch_Normalizationの機能も担っているのではないでしょうか。例えばSENetの論文にて活性化するsigmoid層の値とclassの関係が示されています。https://arxiv.org/pdf/1709.01507v3.pdf

Batch_Normalizationは学習データ全体をバッチ方向のサンプリングで各チャンネル出力を正規化(正則化)する機能ですが、係数の数はチャンネル数(*4)だけです。一方、SE_Blockでは各classによってその係数を変化させることができるようです。例えば「金魚」と「パグ」では係数が違いますが、これは二つのクラスにおける中間層での出力分布が異なっており、SE_Blockではそれぞれの分布を正則化出来る係数を選択できるのかもしれません。
(Batch_Normalizationは全クラスにおける分布を正則化しかできない。これはクラス間分布に差が見つけられない浅い層では問題がないのかもしれないが、クラスごとに出力分布の異なる深い層ではBatch_Normalizationはクラスごとの正則化性能が今一つなのかもしれない)

KerasのLSTMのパラメータ数からのモデル構造推定

KerasのLSTMモデルsummaryのパラメータ数を見てLSTMの実装を想像してみたい。
パラメータ数だけを見た想像なので実際合ってるかの保証はない。

from keras.models import Model
from keras.models import Sequential
from keras.layers.recurrent import LSTM

LENGTH_PER_UNIT = 30
DIMENSION = 5
HIDDEN_LAYER_COUNT = 10

input_shape=(None, LENGTH_PER_UNIT, DIMENSION)

model = Sequential()
model.add(LSTM(HIDDEN_LAYER_COUNT, batch_input_shape=input_shape))
model.add(Dense(1))
model.compile(loss='mse', optimizer='Adam')

model.summary()
print(model.layers[0].input.shape)
print(model.layers[0].output.shape)

......
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
lstm_1 (LSTM)                (None, 10)                640
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 11
=================================================================
Total params: 651
Trainable params: 651
Non-trainable params: 0
_________________________________________________________________
(?, 30, 5)
(?, 10)

LSTMの適当な構成(LENGTH_PER_UNIT = 30, DIMENSION = 5, HIDDEN_LAYER_COUNT = 10)の場合、上記のようなパラメータ数(640)になりました。
これは(15)→(10)の全結合((15+1)*10=160)が計4個入ったモデルであることを示唆している。
注意すべきは(10)→(15)の全結合((10+1)*15=165)は含まれないことである。
この(15)→(10)の全結合が4個入ったモデルを考えると多分、以下AかBのどちらかであろう。
A、Bの大きな違いはCellの次元数だが、モデルの外からこれが見れないのでどちらか分からない。
赤線がLSTMモデルに含まれる(15)→(10)の全結合を示している。


またLENGTH_PER_UNITの値はLSTMのパラメータ数にはなんら影響を及ぼさない。
これは以下のように同じモデル重みを使いまわして多段に重ねる構成なのでLENGTH_PER_UNITを増やしてもパラメータ数は増加しないのだと思われる。

まとめ

LSTMの構造にResNet+SENet構造が入っているというポエムを書いた。