TensorFlow2忘備録 Kerasカスタムモデル で中間層の値を取得する
ディープラーニングの研究や実践では、中間層の値を取得して観察する、勾配の計算学習に使うなど中間層の値を取得したい場合が多々あります。
TensorFlow2において,KerasのSubclassing API を使ってカスタムモデルを作るとき、これをどのように行うか、作ってみました。
Qiitaで既に他の方々が書かれた記事があります.
今回結果的にそれらとは異なった実装にになっています。
学習済みモデルなどをロードして使う場合は解説の範囲外です。
方針
-
call
の引数でフラグを渡し、条件分岐する。 - 中間層がいらない場合は通常通り、モデルの出力のみを
return
する - 中間層を取得する場合、計算の途中のテンソルをリストに保持しておき、モデルの出力と、中間層のテンソルが入ったリストを
return
する
シンプルですがこれで動きます。tf.function
を使った高速化にも対応しており、条件分岐やリストの処理は問題を起こさないようです。(参考)
実装
例として、3個のDense層を含むMLPモデルを作ります。
class MyModel(Model):
''' 3つのdense層をもつMLP'''
def __init__(self):
super(MyModel, self).__init__()
self.d1 = Dense(100,activation='relu')
self.d2 = Dense(200,activation='relu')
self.d3 = Dense(10)
self.ll = [self.d1,self.d2,self.d3]
#中間層の値のリストを返すときはreturn_hidden_states=Trueにして呼び出す
def call(self,x, training=False,return_hidden_states=False):
#return self.forward(x,training=training)
h_list=[] #取り出したい中間層はこのリストにいれる
print('tracing')
for l in self.ll:
x = l(x,training=training)
if return_hidden_states:
h_list.append(x)
if not return_hidden_states:
return x
else:
return x,h_list
return_hidden_states
フラグをTrue
にすると、callされたときに、モデルの出力と中間層のリストの2つを返します。
例:
model = MyModel()
x = np.random.uniform(size=(5,28)).astype(np.float32)
t = np.random.randint(0,10,size=(5,))
# 出力のみの取り出し
y = model(x)
print('y.shape= {}'.format(y.shape))
print('---')
# 中間層込の取り出し
y,h = model(x,return_hidden_states=True)
print('y.shape= {}'.format(y.shape))
print('len(h) = {}'.format(len(h)))
print('h[0].shape = {}'.format(h[0].shape))
print('h[1].shape = {}'.format(h[1].shape))
print('h[2].shape = {}'.format(h[2].shape))
結果
tracing
y.shape= (5, 10)
---
tracing
y.shape= (5, 10)
len(h) = 3
h[0].shape = (5, 100)
h[1].shape = (5, 200)
h[2].shape = (5, 10)
tf.function
を使ったグラフモードでの実行も可能なようです。
普通の学習ステップ
@tf.function
def train_step(x, t):
with tf.GradientTape() as tape:
predictions = model(x, training=True)
loss = loss_object(t, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
print('---')
train_step(x,t) #トレーシングはここで起こる.
print('---')
train_step(x,t) #トレーシングはここでは起こらない
中間層を得るだけの関数
@tf.function
def get_hiddens(x, t):
y,h = model(x, training=False, return_hidden_states=True)
return h
h = get_hiddens(x,t)
中間層で微分してみる
@tf.function
def grad_hiddens(x, t):
with tf.GradientTape() as tape:
predictions,hiddens = model(x, training=True, return_hidden_states=True)
loss = loss_object(t, predictions)
gradients = tape.gradient(loss, hiddens)
return gradients
dLdh = grad_hiddens(x,t)
print('dLdh[0].shape = {}'.format(dLdh[0].shape))
print('dLdh[1].shape = {}'.format(dLdh[1].shape))
print('dLdh[2].shape = {}'.format(dLdh[2].shape))
今後
例では中間層すべてをリターンするようになっていますが、特定の中間層だけ取得したり、引数によってどの中間層を取得するか制御するようコードを変更することもできると思います。
Author And Source
この問題について(TensorFlow2忘備録 Kerasカスタムモデル で中間層の値を取得する), 我々は、より多くの情報をここで見つけました https://qiita.com/yymgt/items/5a24c6135e868442f8ed著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .