Tensorflow2.0モデルを保存およびロードする方法

5648 ワード

ゼロ、総説

  • save/load weights

  • save/load entire model

  • saved_model


  • 一、Save the weights


    1.すべてのパラメータを一度に保存

    model.save_weights('./checkpoints/my_checkpoint') 
    

    2.ウエイトのロード


    この方法でモデルを保存するには、パラメータのみが保存され、ファイルが小さく、ロードが速いが、テスト/導入時に構築ネットワークを再構築する必要があることに注意してください.
    model = create_model() # 
    model.load_weights('./checkpoints/my_checkpoint') # 
    
    loss, acc = model.evaluate(test_images, test_labels)
    
    network.save_weights('weights.ckpt') # 
    print('saved weights')
    del network
    
    network = Sequential([layers.Dense(256)...])# 
    network.compile(optimizer=optimizer.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])
    network.load_weights('weights.ckpt') # 
    network.evaluate(ds_val)
    

    二.Save the model


    この方法はモデルも保存し,ファイルが大きく,効率が低い.
    # 
    network.save('model.h5')
    # 
    del network
    # 
    network = tf.keras.models.load_model('model.h5')
    network.evaluate(x_val, y_val)
    

    #三、ONNXはonnxとして保存され、これは汎用フォーマットであり、pythonが生成したのはc++で解析することができ、一般的にpythonは訓練してC++で配置する.なお、ONNXは、NVIDIAの組み込み機器に配備するためにTensorRTを転送することができる.
    
    tf.saved_model.save(m, '/tmp/saved_model/') # 
    
    imported = tf.saved_model.load(path) # Load
    f = imported.signatures["serving_default"]