pytorchトレーニングモデルの保存とリカバリ

4154 ワード

モデルトレーニング後、テストと導入のためにファイルに保存する必要があります.または、直前の訓練状態を継続する.
https://pytorch.org/tutorials/beginner/saving_loading_models.html
1. Best Practices
https://github.com/pytorch/pytorch/blob/761d6799beb3afa03657a71776412a2171ee7533/docs/source/notes/serialization.rst
主に2種類のモデルのシーケンス化保存とロード回復の方法がある.1.1方法M 1-推奨
リカバリモデルパラメータ(model parameters):
import torch 

#  
torch.save(the_model.state_dict(), PATH)

#  
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

この方法は、モデルのネットワーク構造情報を独自に導入する必要がある.
1.2方法M 2
モデルのパラメータとネットワーク構造情報を同時に保存します.
import torch

#  
torch.save(the_model, PATH)

#  
the_model = torch.load(PATH)

この方法で保存されたデータは、特定のclassesと使用される正確なディレクトリ構造にバインドされている.'#このため、再ロード後に多くの再構成を経ると、混乱するおそれがある.
2.Stackoverflow回答
From: Best way to save a trained model in PyTorch?
適用シーンに基づいて、モデルの保存とロードの復元方法を選択する.シーンC 1-モデルは推定用に保存されます
自分でモデルを保存し、自分でモデルを復元し、その後、モデルをevaluationモードに変更する.
これは、デフォルトでは、ネットワークモデルの訓練時にBatchNormとDropoutのネットワーク層があることが多いためである.
#  
torch.save(model.state_dict(), filepath)

#  
model.load_state_dict(torch.load(filepath))
model.eval()

シーンC 2-モデル保存リカバリトレーニング用
モデル訓練時、その訓練状態を維持する.モデルモデル、オプティマイザの状態(optimizer state)、epochs、scoreなどを同時に保存する必要がある.
#  
state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),


     ...
    }
    torch.save(state, filepath)
#  , 
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
#  ,  model.eval().

シーンC 3-モデル保存を共有する
TensorFlowでは、1つ作成できます.pbファイルは、ネットワーク構造とモデルの重みを定義する.この方式は非常に便利で、特にTensorflow serveを使用する.
同様にPytorchでは
#  
torch.save(model, filepath)

#  
model = torch.load(filepath)

Pytorchはまだバージョンの更新が変化しているため、この方法はまだ安定していない.だからお勧めしません.3.例
From  PyTorch 

import torch

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'best_score': best_score,
    ...
}

torch.save(state, '/path/to/checkpoint.pth' )

if resume:
    if os.path.isfile(resume_file):
        print("=> loading checkpoint '{}'".format(resume_file))
        checkpoint = torch.load(resume_file)
        start_epoch = checkpoint['epoch']
        best_score = checkpoint['best_score']
        model.load_state_dict(checkpoint['state_dict'])

もう一つの比較的完全な例
#saving
torch.save({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, 'checkpoint.tar' )

#loading

if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.evaluate, checkpoint['epoch']))

モデルネットワーク層のパラメータの可視化:
import torch.nn as nn
from collections import OrderedDict

#  
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,32,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(32,64,5)),
('relu2', nn.ReLU())
]))
print(model)

#  
params=model.state_dict()
for k,v in params.items():
    print(k) #  
    print(params['conv1.weight']) # conv1   weight
    print(params['conv1.bias']) # conv1   bias

参考文献:https://www.aiuai.cn/aifarm657.html https://pytorch.org/tutorials/beginner/saving_loading_models.html https://byjiang.com/2017/06/05/How_To_Save_And_Restore_Model/https://www.pytorchtutorial.com/pytorch-note5-save-and-restore-models/