pytorchプリトレーニングモデルのロード
3527 ワード
pytorchプリトレーニングのmodelをロード
ブロガーも深さ分類ネットワークを使っていくつかの応用を実現し、比較試験を行うことを学び始めたばかりで、pytorchはみんなが極力推薦したdeep learningの枠組みで、python自身の言語スタイルとよく合っていて、matlabを書く感じがして、広く愛されています.博主は1篇の文章の対比実験の中でVGG、ResNetとXceptionNetと私の分類の任務の中で対比したいと思って、前もどんな学習の枠組みの応用に対して全然通じなくて、いくつかの材料を調べた後に自分で総括して、分かち合います.
私たちはあるネットワークを訓練する時(特にいくつかの名声の高いネットワーク)、ほとんど予備訓練モデルを使用して、モデルの訓練の起点が比較的に良いならば、適切な学習率の下でlossを迅速に収束させて、学習のプロセスを加速させて、fine-tuningとも呼ばれます.
プリトレーニングモデルをロードする本質的な問題は,使用したいパラメータのモデル構造が現在のモデル構造と完全に同じではなく,最も簡単なのは最後の全接続層のニューロン数が異なることである.本稿では、このような状況に対して、最も直接的な応用シーンは、例えば、1000分類のVGG-19のネットワークパラメータを初期パラメータとしてfine-tuningをしたいが、私の分類タスクはVGG-19を使用して簡単な二分類を実現するだけで、この場合、全接続層の前のパラメータを私のネットワークにロードするためにいくつかの操作が必要である.
VGG-19
ResNet-50
XceptionNet
まとめ
VGGロード時にロードするモデルが現在のモデルパラメータ構造と一致しなくても、
リネットやXceptionのようにもう一言追加する必要がなく、ロードを実現できます.
具体的な原因は私もはっきりしていないので、大神に教えてもらいたいのですが、このような方法でロードするのは親測で利用できます.
ブロガーも深さ分類ネットワークを使っていくつかの応用を実現し、比較試験を行うことを学び始めたばかりで、pytorchはみんなが極力推薦したdeep learningの枠組みで、python自身の言語スタイルとよく合っていて、matlabを書く感じがして、広く愛されています.博主は1篇の文章の対比実験の中でVGG、ResNetとXceptionNetと私の分類の任務の中で対比したいと思って、前もどんな学習の枠組みの応用に対して全然通じなくて、いくつかの材料を調べた後に自分で総括して、分かち合います.
私たちはあるネットワークを訓練する時(特にいくつかの名声の高いネットワーク)、ほとんど予備訓練モデルを使用して、モデルの訓練の起点が比較的に良いならば、適切な学習率の下でlossを迅速に収束させて、学習のプロセスを加速させて、fine-tuningとも呼ばれます.
プリトレーニングモデルをロードする本質的な問題は,使用したいパラメータのモデル構造が現在のモデル構造と完全に同じではなく,最も簡単なのは最後の全接続層のニューロン数が異なることである.本稿では、このような状況に対して、最も直接的な応用シーンは、例えば、1000分類のVGG-19のネットワークパラメータを初期パラメータとしてfine-tuningをしたいが、私の分類タスクはVGG-19を使用して簡単な二分類を実現するだけで、この場合、全接続層の前のパラメータを私のネットワークにロードするためにいくつかの操作が必要である.
VGG-19
model = vgg19(num_classes = 2) # VGG-19 , import
if use_gpu: # GPU
model = model.cuda()
model = torch.nn.DataParallel(model)
pretrained_dict = torch.load("lujing/vgg19-dcbb9e9d.pth") #
pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #
model_dict.update(pretrained_dict) #
model.load_state_dict(model_dict)
ResNet-50
model = resnet50(num_classes = 2) # ResNet-50 , import
if use_gpu: # GPU
model = model.cuda()
model = torch.nn.DataParallel(model)
pretrained_dict = torch.load("lujing/resnet50-6-classes.pth") # , 6
pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #
del pretrained_dict["module.fc.weight"] #
del pretrained_dict["module.fc.bias"]
model_dict.update(pretrained_dict) #
model.load_state_dict(model_dict)
XceptionNet
model = xception(num_classes = 2) # ResNet-50 , import
if use_gpu: # GPU
model = model.cuda()
model = torch.nn.DataParallel(model)
pretrained_dict = torch.load("lujing/xception-6-classes.pth") # , 6
pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #
del pretrained_dict["module.fc.weight"] #
del pretrained_dict["module.fc.bias"]
model_dict.update(pretrained_dict) #
model.load_state_dict(model_dict)
まとめ
VGGロード時にロードするモデルが現在のモデルパラメータ構造と一致しなくても、
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
リネットやXceptionのようにもう一言追加する必要がなく、ロードを実現できます.
del pretrained_dict["module.fc.weight"]
del pretrained_dict["module.fc.bias"]
具体的な原因は私もはっきりしていないので、大神に教えてもらいたいのですが、このような方法でロードするのは親測で利用できます.