kerasのRNN系APIの引数return_state, return_sequencesについて
3166 ワード
kerasのRNN系APIのGRUの引数return_state, return_sequencesについて
大雑把に書きました。
環境
python3
使用する擬似データ
B = 1 #バッチサイズ
T = 10 #時系列長
N = 1000 #特徴量
data = np.random.randn(B, T, N)
使用するRNN系インターフェース
tf.keras.layers.GRU
return_state=True, return_sequences=True
赤丸 : return_sequencesをTrue時
緑丸 : return_statesをTrue時
gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 10, 256)
緑丸: (1, 256)
return_state=True, return_sequences=False
赤丸 : return_sequencesをFalse時
緑丸 : return_statesをTrue時
gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 256)
緑丸: (1, 256)
return_state=False, return_sequences=True
gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 10, 256)
緑丸なし
return_state=False, return_sequences=False
gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 256)
緑丸なし
Author And Source
この問題について(kerasのRNN系APIの引数return_state, return_sequencesについて), 我々は、より多くの情報をここで見つけました https://qiita.com/niwaka_dev/items/1b84da7872ac98ae694f著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .