keras ModelCheckpointはブレークポイントの継続訓練機能を実現する

8231 ワード

参照リンク:
壁割れ推奨:https://cloud.tencent.com/developer/article/1049579
英語版原文:https://machinelearningmastery.com/check-point-deep-learning-models-keras/
kerasドキュメントコールバック関数:http://keras-cn.readthedocs.io/en/latest/other/callbacks/#modelcheckpoint
まずModelCheckpointのパラメータを見てみましょう.
keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    period=1
)

1. filename:   ,       
2. monitor:      ,val_acc  val_loss
3. verbose:      ,01  
4. save_best_onlyTrue5. mode‘auto’minmaxsave_best_only=True              ,  ,     val_accmaxval_lossminauto   ,                。
6. save_weights_onlyTrue,        ,         (      ,     )
7. periodCheckPoint      epoch 

コード実装プロセス:
①keras.callbacksからModelCheckpointクラスをインポート
from keras.callbacks import ModelCheckpoint

②  訓練段階のmodel.compileの後に以下のコードを加えて、epoch(period=1)ごとに最適なパラメータを保存することを実現します.
checkpoint = ModelCheckpoint(filepath,
    monitor='val_loss', save_weights_only=True,verbose=1,save_best_only=True, period=1)
注意:filepathはパラメータを保存するパスです.ここでは「logs/000/trained_best_weights.h 5」です.
③訓練段階のmodel.fitの前に以前に保存したパラメータをロードする
if os.path.exists(filepath):
    model.load_weights(filepath)
    #     print("checkpoint_loaded")

④model.fitにcallbacks=[checkpoint]を追加してコールバックを実現
model.fit_generator(data_generator_wrap(lines[:num_train], batch_size, input_shape, anchors, num_classes),
        steps_per_epoch=max(1, num_train//batch_size),
        validation_data=data_generator_wrap(lines[num_train:], batch_size, input_shape, anchors, num_classes),
        validation_steps=max(1, num_val//batch_size),
        epochs=3,
        initial_epoch=0,
        callbacks=[checkpoint])

テスト出力:
①初回出力は、パラメータなしでロードでき、「checkpoint_loaded」は印刷されず、以下のように出力される(テストepoch=3)
keras ModelCheckpoint 实现断点续训功能_第1张图片
keras ModelCheckpoint 实现断点续训功能_第2张图片
②train.pyを再度実行し、先ほどのコードで直接model.fitの前に前回のトレーニングを保存したパラメータをロードし、トレーニングを継続することができます(lossの変化).注意「checkpoint_load」が出力されると、前に保存したパラメータが正常にロードされたことを示します.
keras ModelCheckpoint 实现断点续训功能_第3张图片
keras ModelCheckpoint 实现断点续训功能_第4张图片
ヒント:リファレンスリンクには簡単なテストコードがあります.以上は私のトレーニングデータで行ったテストだけです.詳細はリファレンスリンクを参照してください.
The end.