TensorFlow2忘備録 Kerasカスタムモデル で中間層の値を取得する


ディープラーニングの研究や実践では、中間層の値を取得して観察する、勾配の計算学習に使うなど中間層の値を取得したい場合が多々あります。
TensorFlow2において,KerasのSubclassing API を使ってカスタムモデルを作るとき、これをどのように行うか、作ってみました。

Qiitaで既に他の方々が書かれた記事があります.

今回結果的にそれらとは異なった実装にになっています。

学習済みモデルなどをロードして使う場合は解説の範囲外です。

方針

  • callの引数でフラグを渡し、条件分岐する。
  • 中間層がいらない場合は通常通り、モデルの出力のみをreturnする
  • 中間層を取得する場合、計算の途中のテンソルをリストに保持しておき、モデルの出力と、中間層のテンソルが入ったリストをreturnする

シンプルですがこれで動きます。tf.functionを使った高速化にも対応しており、条件分岐やリストの処理は問題を起こさないようです。(参考)

実装

例として、3個のDense層を含むMLPモデルを作ります。


class MyModel(Model):
  '''  3つのdense層をもつMLP'''
  def __init__(self):
    super(MyModel, self).__init__()
    self.d1 = Dense(100,activation='relu')
    self.d2 = Dense(200,activation='relu')
    self.d3 = Dense(10)
    self.ll = [self.d1,self.d2,self.d3]


  #中間層の値のリストを返すときはreturn_hidden_states=Trueにして呼び出す
  def call(self,x, training=False,return_hidden_states=False):
    #return self.forward(x,training=training)
    h_list=[] #取り出したい中間層はこのリストにいれる
    print('tracing')    
    for l in self.ll:
      x = l(x,training=training)

      if return_hidden_states:
        h_list.append(x)

    if not return_hidden_states:
      return x 
    else:
      return x,h_list

return_hidden_states フラグをTrueにすると、callされたときに、モデルの出力と中間層のリストの2つを返します。

例:

model = MyModel()

x = np.random.uniform(size=(5,28)).astype(np.float32)
t = np.random.randint(0,10,size=(5,))

# 出力のみの取り出し
y = model(x)
print('y.shape= {}'.format(y.shape))
print('---')
# 中間層込の取り出し
y,h = model(x,return_hidden_states=True)
print('y.shape= {}'.format(y.shape))
print('len(h) =  {}'.format(len(h)))
print('h[0].shape =  {}'.format(h[0].shape))
print('h[1].shape =  {}'.format(h[1].shape))
print('h[2].shape =  {}'.format(h[2].shape))

結果

tracing
y.shape= (5, 10)
---
tracing
y.shape= (5, 10)
len(h) =  3
h[0].shape =  (5, 100)
h[1].shape =  (5, 200)
h[2].shape =  (5, 10)

tf.functionを使ったグラフモードでの実行も可能なようです。

普通の学習ステップ


@tf.function
def train_step(x, t):
  with tf.GradientTape() as tape:
    predictions = model(x, training=True)
    loss = loss_object(t, predictions)
  gradients = tape.gradient(loss, model.trainable_variables) 
  optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 

print('---')
train_step(x,t) #トレーシングはここで起こる.
print('---')
train_step(x,t) #トレーシングはここでは起こらない

中間層を得るだけの関数

@tf.function
def get_hiddens(x, t):
  y,h = model(x, training=False, return_hidden_states=True)
  return h

h = get_hiddens(x,t)

中間層で微分してみる


@tf.function
def grad_hiddens(x, t):
  with tf.GradientTape() as tape:
    predictions,hiddens = model(x, training=True, return_hidden_states=True)
    loss = loss_object(t, predictions)

  gradients = tape.gradient(loss, hiddens) 
  return gradients

dLdh = grad_hiddens(x,t)
print('dLdh[0].shape = {}'.format(dLdh[0].shape))
print('dLdh[1].shape = {}'.format(dLdh[1].shape))
print('dLdh[2].shape = {}'.format(dLdh[2].shape))

今後

例では中間層すべてをリターンするようになっていますが、特定の中間層だけ取得したり、引数によってどの中間層を取得するか制御するようコードを変更することもできると思います。