Boost camp Week3 Pytorch(2)

10596 ワード

Load Model


model.save()


学習結果の共有と格納に必要な関数
  • 学習結果を格納する関数
  • モデル形状(アーキテクチャ)とパラメータを格納
  • 保存モデルによる学習中間過程最適結果モデルの選択
  • 作成したモデルを外部研究者と共有して学習再現性を高める
  • # model's state_dict
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
    	print(param_tensor, "\t", model.state_dict([param_tensor].size())
    
    torch.save(model.state_dict(), os.path.join(MODEL_PATH, "model.pt"))
    
    # 같은 모델의 형태에서 파라메터만 load 
    new_model = TheModelClass() new_model.load_state_dict(torch.load(os.path.join(MODEL_PATH, "model.pt")))
    
    #모델의 architecture와 함께 저장
    torch.save(model, os.path.join(MODEL_PATH, "model.pt"))
    model = torch.load(os.path.join(MODEL_PATH, "model.pt"))

    checkpoints

  • 学習の中間結果を保存し、ベスト結果を選択
  • earlystopテクニック使用時、以前学んだ成果を保存
  • 記憶損失とメトリック値の継続確認
  • 一般的にepoch、loss、metricを格納して確認する
  • colab継続学習が必要
  • import torch
    torch.save({ 'epoch': e,
    	'model_state_dict': model.state_dict(),
    	'optimizer_state_dict': optimizer.state_dict(),
        	'loss': epoch_loss},
    	f"saved/checkpoint_model_... .pt")
        
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    Transfer Learning

  • 他のデータセットで作成したモデルを現在のデータに適用する
  • 通常ビッグデータセットで作成されるモデルの性能↑
  • 現在DLで最もよく使われている学習方法
  • マスターアーキテクチャ学習の良いモデルの一部
    変更学習

  • Freezing


    pretrainモデルを使用すると、モデルの一部が凍結されます.
    vgg = models.vgg16(pretrained=True).to(device)
    class MyNewNet(nn.Module):
        def __init__(self):
    	super(MyNewNet, self).__init__()
        	self.vgg19 = models.vgg19(pretrained=True)
            self.linear_layers = nn.Linear(1000, 1)
    	# Defining the forward pass
        def forward(self, x):
    	x = self.vgg19(x)
    	return self.linear_layers(x)
    
    #Freezing
    for param in my_model.parameters():
    	param.requires_grad = False
    for param in my_model.linear_layers.parameters():
    	param.requires_grad = True

    Montitoring Tools


    学習時間が長く、記録するツールが必要です.
  • Print文使用
  • Tensorboard
  • Weights and Biases
  • Tensorboard

  • TensorFlowプロジェクト作成のビジュアル化ツール
  • 学習チャート、メジャー、学習結果の可視化サポート
  • PyTorch接続可能→DL可視化コアツール
  • 特徴
  • スカラー:メートル法等の定数値を示す連続(epoch)
  • 図:表示モデルの計算図
  • ヒストグラム:重み値の分布を示す
  • 画像:予測値と実績値の比較
  • mesh:3 d形状データを表すツール
  • Weights & Biases

  • 機械学習実験をサポートする常用ツール
  • 連携、コードバージョン管理、実験結果記録等提供
  • MLOpsの代表的なツールで、低拡張中
  • Multi-GPU


    Hyperparameter Tuning


    PyTorch Troubleshooting