kerasのRNN系APIの引数return_state, return_sequencesについて


kerasのRNN系APIのGRUの引数return_state, return_sequencesについて

大雑把に書きました。

環境
python3

使用する擬似データ

B = 1  #バッチサイズ
T = 10 #時系列長
N = 1000 #特徴量
data = np.random.randn(B, T, N)

使用するRNN系インターフェース

tf.keras.layers.GRU

return_state=True, return_sequences=True

赤丸 : return_sequencesをTrue時
緑丸 : return_statesをTrue時

gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 10, 256)
緑丸: (1, 256)

return_state=True, return_sequences=False

赤丸 : return_sequencesをFalse時
緑丸 : return_statesをTrue時

gru = tf.keras.layers.GRU(256, return_state=True, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs, states = gru(data)
print("赤丸:", outputs.shape)
print("緑丸:", states.shape)
赤丸: (1, 256)
緑丸: (1, 256)

return_state=False, return_sequences=True

gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=True)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 10, 256)
緑丸なし

return_state=False, return_sequences=False

gru = tf.keras.layers.GRU(256, return_state=False, return_sequences=False)
B = 1
T = 10
N = 1000
data = np.random.randn(B, T, N)
outputs = gru(data)
print("赤丸:", outputs.shape)
print("緑丸なし")
赤丸: (1, 256)
緑丸なし