GradientTapeを使ってfitするとエラーが出た。そして解決した。


GradientTapeを使ってfitするとエラー出た

Variatioanal AutoEncoderを実行する際に普段はself.add_loss()を使い学習させていました。
以下のようなコードです。

 #モデルを構築
class VAE(tf.keras.Model):

  def __init__(self,latent_dim):
    super(VAE,self).__init__()
    self.encoder = Encode(latent_dim)
    self.decoder = Decode()
  
  #lossを計算
  def VAE_loss(self,x):
    mean,logvar,z = self.encoder(x)
    x_sigmoid = self.decoder(z)

    #reshape
    shape = tf.shape(x)
    x = tf.reshape(x,[shape[0],-1])
    x_sigmoid = tf.reshape(x_sigmoid,[shape[0],-1])

    #復元誤差
    reconstruction_loss = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(x,x_sigmoid)
    #MNISTのshape
    image_shape=28*28
    reconstruction_loss *= image_shape
   
    #KL-divergence(正規分布との差)
    kl_divergence =-0.5* tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar),axis=1)
    vae_loss = tf.reduce_mean(reconstruction_loss + kl_divergence)
    return vae_loss

  def call(self,inputs):
    loss = self.VAE_loss(inputs)
    self.add_loss(loss,inputs=inputs)
    return x

encoder、decoderのコードは以下から見てください。
Variatinal AutoEoncoderをSubclassing APIで書いてみた

call関数内でVAE_lossという独自の損失関数を定義し、そのlossをself.add_lossに入れます。これでmodel.compile、model.fitをすることでトレーニング開始してくれます。
しかし、このself.add_lossは、研究室で先輩方や同期に聞いても知らないという人が多くて、GradientTapeを使って学習させてると。ならば自分もGradientTapeを使おうと思ったわけですよ。
上のコードを書き換えます。

 #モデルを構築
class VAE(tf.keras.Model):

  def __init__(self,latent_dim):
    super(VAE,self).__init__()
    self.encoder = Encode(latent_dim)
    self.decoder = Decode()
  
  def call(self,x):
    mean,logvar,z = self.encoder(x)
    y_pred = self.decoder(z)
    return mean,logvar,z,y_pred
    
  #lossを計算
  def VAE_loss(self,x,x_sigmoid,mean,logvar,z):
    #reshape
    shape = tf.shape(x)
    x = tf.reshape(x,[shape[0],-1])
    x_sigmoid = tf.reshape(x_sigmoid,[shape[0],-1])
    #復元誤差
    reconstruction_loss = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(x,x_sigmoid)
    image_shape=28*28
    reconstruction_loss *= image_shape

    #KL-divergence(正規分布との差)
    kl_divergence =-0.5* tf.reduce_sum(1 + logvar - tf.square(mean) - tf.exp(logvar),axis=1)
    vae_loss = tf.reduce_mean(reconstruction_loss + kl_divergence)
    return vae_loss

  def train_step(self,x):
    with tf.GradientTape() as tape:
        mean,logvar,z,y_pred= self(x, training=True)
        loss = self.VAE_loss(x,y_pred,mean,logvar,z)
    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    return {"loss":loss}

さて、これでOKと思い、実行すると

TypeError: Dimension value must be integer or None or have an __index__ method, got TensorShape([None, 28, 28, 1])

え?どゆこと?特におかしくなるようにいじったわけではないが、なぜかエラーが出る。Google様でこのエラー文を検索しても解決策が出てこない...

解決しました

さて問題が解決しました。なにがダメだったのかをみると、入力データの型でした。

 #train_stepに入る前の型
type(train_images)
<class 'numpy.ndarray'>
 #train_stepに入ったあとの型
type(x)
<class 'tuple'>

型がなんかtupleに変わってますね。なんでtupleになってんだ?

if isinstance(x, tuple):
  x = x[0]
 #TensorShape([None, 28, 28, 1])

これでtupleからTensorShapeに変わりました。
このコードをwith tf.GradientTape() as tape:の上に書くと、動きました。これに気づくまでかなり時間かかりました。ずっとほかの部分を見てました。同じようなエラーで悩んでる方は参考にしてください。