Pytorchのモデルのロード/保存
pytorchモデルを保存するには、次の2つの方法があります.モデル全体(構造+パラメータ) を保存パラメータのみ保存(公式推奨) どちらも
モデル全体の保存
この方法は簡単で、保存とロードは2行のコードで、Python pickleパッケージの使い方と同じで、modelを1つのオブジェクトとして直接保存してロードすればいいです.
Note:PyTorch約定使用.ptまたは.pth接尾辞名前保存ファイル.
パラメータの保存
この方法を重点的に紹介します.一般的に1つのモデルを訓練した後、私たちは単独で1つのモデルのパラメータだけを保存しません.回復訓練、パラメータの移行など、後続の操作を容易にするために、現在の回転状態のスナップショットを保存します.具体的な情報は自分の必要に応じて、以下にいくつかの方面をリストします.モデルパラメータ オプティマイザパラメータ loss epoch args
これらの情報を辞書で包装して保存すればいいです.
この方法で保存されているモデルはパラメータにすぎないので、ロード時にモデルを作成してからパラメータをロードする必要があります.以下のようにします.
Note:PyTorch約定使用.ptまたは.pth接尾辞名前保存ファイル.
torch.save(obj, dir)
で実現され、この関数の役割はオブジェクトをディスクに保存することであり、その内部はPythonのpickleで実現されている.2つの方法の違いは実はobjパラメータの違いである:前者のobjはモデル全体のオブジェクトであり、後者のobjはモデルから取得したモデルパラメータを格納した辞書であり、2つ目を推奨する.面倒だが、比較的柔軟で、予備訓練、パラメータ移行などの操作を実現するのに有利である.モデル全体の保存
この方法は簡単で、保存とロードは2行のコードで、Python pickleパッケージの使い方と同じで、modelを1つのオブジェクトとして直接保存してロードすればいいです.
#
model = Mymodel()
torch.save(model, path)
#
model = torch.load(path)
Note:PyTorch約定使用.ptまたは.pth接尾辞名前保存ファイル.
パラメータの保存
この方法を重点的に紹介します.一般的に1つのモデルを訓練した後、私たちは単独で1つのモデルのパラメータだけを保存しません.回復訓練、パラメータの移行など、後続の操作を容易にするために、現在の回転状態のスナップショットを保存します.具体的な情報は自分の必要に応じて、以下にいくつかの方面をリストします.
これらの情報を辞書で包装して保存すればいいです.
この方法で保存されているモデルはパラメータにすぎないので、ロード時にモデルを作成してからパラメータをロードする必要があります.以下のようにします.
#
save_data = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'epoch': epoch,
'args': args
...
}
#
torch.save(save_data , path)
load_data = torch.load(path)
model = Mymodel()
optimizer = Myoptimizer()
#
model.load_state_dict(load_data ['model_state_dict'])
optimizer.load_state_dict(load_data ['optimizer_state_dict'])
...
Note:PyTorch約定使用.ptまたは.pth接尾辞名前保存ファイル.