Pytorchのモデルのロード/保存


pytorchモデルを保存するには、次の2つの方法があります.
  • モデル全体(構造+パラメータ)
  • を保存
  • パラメータのみ保存(公式推奨)
  • どちらもtorch.save(obj, dir)で実現され、この関数の役割はオブジェクトをディスクに保存することであり、その内部はPythonのpickleで実現されている.2つの方法の違いは実はobjパラメータの違いである:前者のobjはモデル全体のオブジェクトであり、後者のobjはモデルから取得したモデルパラメータを格納した辞書であり、2つ目を推奨する.面倒だが、比較的柔軟で、予備訓練、パラメータ移行などの操作を実現するのに有利である.
    モデル全体の保存
    この方法は簡単で、保存とロードは2行のコードで、Python pickleパッケージの使い方と同じで、modelを1つのオブジェクトとして直接保存してロードすればいいです.
    #   
    model = Mymodel()
    torch.save(model, path)
    #    
    model = torch.load(path)
    

    Note:PyTorch約定使用.ptまたは.pth接尾辞名前保存ファイル.
    パラメータの保存
    この方法を重点的に紹介します.一般的に1つのモデルを訓練した後、私たちは単独で1つのモデルのパラメータだけを保存しません.回復訓練、パラメータの移行など、後続の操作を容易にするために、現在の回転状態のスナップショットを保存します.具体的な情報は自分の必要に応じて、以下にいくつかの方面をリストします.
  • モデルパラメータ
  • オプティマイザパラメータ
  • loss
  • epoch
  • args

  • これらの情報を辞書で包装して保存すればいいです.
    この方法で保存されているモデルはパラメータにすぎないので、ロード時にモデルを作成してからパラメータをロードする必要があります.以下のようにします.
    #       
    save_data = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'epoch': epoch,
        'args': args
         ...
    }
    #   
    torch.save(save_data , path)
    load_data = torch.load(path)
    model = Mymodel()
    optimizer = Myoptimizer()
    #     
    model.load_state_dict(load_data ['model_state_dict'])
    optimizer.load_state_dict(load_data ['optimizer_state_dict'])
    ...
    

    Note:PyTorch約定使用.ptまたは.pth接尾辞名前保存ファイル.