Kerasで可変長系列をEmbeddingしてLSTMに入力するときはmask_zero=Trueにする
KerasのLSTMに文章を入力して何らかの評価値を出力する場合、よく次のような流れで処理すると思います。
- 文章を形態素解析器で単語に分割
- 単語系列の系列長を揃える
- 単語ベクトルでEmbedding
- LSTMに入力
- 評価値を出力
「系列長を揃える」というのは、具体的には長すぎる系列を切断したり短すぎる系列をゼロパディングしたりします。
さて、これで問題なく学習できるかと思いきや、学習後にどんな文章を入力しても同じ評価値しか出力しないという状況に陥ってしまいました。
実はこのままだとゼロパディングした部分を0の系列として扱うため、うまく学習できなくなるケースがあるようです。
Embeddingレイヤーでmask_zero=True
にすると、ゼロパディングした部分を無視(?)してくれるようです。
- mask_zero: 真理値.入力の0をパディングのための特別値として扱うかどうか. これは入力の系列長が可変長となりうる変数を入力にもつRecurrentレイヤーに対して有効です. この引数が
True
のとき,以降のレイヤーは全てこのマスクをサポートする必要があり, そうしなければ,例外が起きます. mask_zeroがTrueのとき,index 0は語彙の中で使えません(input_dim は語彙数+1
と等しくなるべきです).
プログラムは次のようになります。
# 文章を形態素解析器で単語に分割
import MeCab
tagger = MeCab.Tagger('-Owakati')
texts = [tagger.parse(x).rstrip() for x in texts]
# 単語をインデックスに変換
from keras.preprocessing.text import Tokenizer
tokenizer = Tokenizer()
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)
# 単語系列の系列長を揃える
from keras.preprocessing.sequence import pad_sequences
maxlen = 8000
X = pad_sequences(sequences, maxlen=maxlen)
print(X.shape) # (2000, 8000)
# embedding_matrixを作成
embedding_matrix = np.zeros((max(tokenizer.word_index.values())+1, word2vec.vector_size))
for word, i in tokenizer.word_index.items():
if word in wv:
embedding_matrix[i] = word2vec.wv[word]
# Embedding&LSTM
main_input = Input(shape=(x.shape[1],), name='main_input')
embedding1 = Embedding(
input_dim=embedding_matrix.shape[0],
output_dim=embedding_matrix.shape[1],
weights=[embedding_matrix],
trainable=False,
mask_zero=True,
name='embedding1')(main_input)
lstm1 = LSTM(32, name='lstm1')(embedding1)
main_output = Dense(8, name='main_output')(lstm1)
model = Model(inputs=main_input, outputs=main_output)
model.compile(loss='mse', optimizer='adam')
embedding_matrix
はインデックスをキーとする単語ベクトルの辞書です。
mask_zero
の説明に「index 0は語彙の中で使えません」とあるので、1から開始するようにします。
このようにすると、文章ごとに異なる評価値を出力するようになりました。
Author And Source
この問題について(Kerasで可変長系列をEmbeddingしてLSTMに入力するときはmask_zero=Trueにする), 我々は、より多くの情報をここで見つけました https://qiita.com/hrappuccino/items/f66abebe60f8ea7826d5著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .