TensorFlow2でのモデルファイルの保存と読み込み


Tensorflow2 (とkeras) でのmodelファイルの形式・保存方法が色々あってむずかしいのでメモ

modelファイルの種類

  • .h5: KerasでSequentialモデル全体を一括保存したファイル
    • モデルの構造、重み、訓練の設定、optimizerの情報がHDF5ファイルにまとまってる
  • SavedModel (ディレクトリ): TensorFlowのSavedModel形式
    • 重みを含むチェックポイントファイルとモデルのグラフ構造を持ったprotoファイルが一つのディレクトリに作られる

開発時は基本的にSavedModel形式で保存しておき、デプロイする環境に応じてフォーマットを変えるのが一般的っぽい?
TensorFlow ServingにはこのSavedModelを使うらしい

1. Kerasで訓練済のSequential modelをsave/load

Save and serialize models with Keras の要約:

save

モデル全体を単一ファイルに保存

model.save('path_to_my_model.h5')

HDF5形式でモデルの構造、重み、訓練の設定、optimizerの情報が1つのファイルに保存される

SavedModel形式で保存


model.save('path_to_saved_model', save_format='tf')

path_to_saved_model ディレクトリに重みを含むチェックポイントファイルとモデルのグラフのpbファイルが保存される

SavedModel保存についてはより低レベルで高機能な tf.saved_model を使う方法もある:

tf.saved_model.save(my_trained_model, "mymodel/1/")

モデルの構造や重みだけ保存する

モデルの構造だけ欲しい場合 model.get_config() で学習済みモデルの構造だけ取得し、
keras.Model.from_config(config) で新しいモデルを作って保存する。詳細は公式ガイドのArchitecture only savingなど参照

load

h5ファイルからモデルをロード

new_model = keras.models.load_model('path_to_my_model.h5')

SavedModelからモデルをロード

new_model = tf.keras.models.load_model("mymodel/1/")

自動で判別してくれるみたいなのでh5ファイル or SavedModelディレクトリのパスを渡すだけ

2. Kerasで訓練の過程でSequential modelをsave/load

Save and load models

checkpointを使った重みの保存

checkpoint機能を使うことで訓練途中の重みを随時保存できる。
model.fitの際にcallbackを指定しておくことで、訓練中に更新される重みがファイルに保存される


cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='training_1/cp.ckpt', 
    save_weights_only=True)

model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images,test_labels),
          callbacks=[cp_callback])

詳細は Checkpoint callback usage 参照

明示的に保存

例えば訓練前の最初の重みも保存する場合


model = create_model()
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
model.save_weights(checkpoint_path.format(epoch=0))

補足

Note: the default tensorflow format only saves the 5 most recent checkpoints.

最新5つのチェックポイントだけ保存されるらしい

checkpointファイルから重みをロード

  1. 未訓練の新しいmodelを作る
  2. `model.load_weights(checkpoint_path) で保存したチェックポイントファイルから重みをロード

当然ながら保存した時とネットワークの構造が同じでないとloadできない