pytorchプリトレーニングモデルのロード


pytorchプリトレーニングのmodelをロード
ブロガーも深さ分類ネットワークを使っていくつかの応用を実現し、比較試験を行うことを学び始めたばかりで、pytorchはみんなが極力推薦したdeep learningの枠組みで、python自身の言語スタイルとよく合っていて、matlabを書く感じがして、広く愛されています.博主は1篇の文章の対比実験の中でVGG、ResNetとXceptionNetと私の分類の任務の中で対比したいと思って、前もどんな学習の枠組みの応用に対して全然通じなくて、いくつかの材料を調べた後に自分で総括して、分かち合います.
私たちはあるネットワークを訓練する時(特にいくつかの名声の高いネットワーク)、ほとんど予備訓練モデルを使用して、モデルの訓練の起点が比較的に良いならば、適切な学習率の下でlossを迅速に収束させて、学習のプロセスを加速させて、fine-tuningとも呼ばれます.
プリトレーニングモデルをロードする本質的な問題は,使用したいパラメータのモデル構造が現在のモデル構造と完全に同じではなく,最も簡単なのは最後の全接続層のニューロン数が異なることである.本稿では、このような状況に対して、最も直接的な応用シーンは、例えば、1000分類のVGG-19のネットワークパラメータを初期パラメータとしてfine-tuningをしたいが、私の分類タスクはVGG-19を使用して簡単な二分類を実現するだけで、この場合、全接続層の前のパラメータを私のネットワークにロードするためにいくつかの操作が必要である.
VGG-19
model = vgg19(num_classes = 2)                                                   #      VGG-19   ,        import
if use_gpu:                                                                      #   GPU
     model = model.cuda()
     model = torch.nn.DataParallel(model)

pretrained_dict = torch.load("lujing/vgg19-dcbb9e9d.pth")                        #         

pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}  #         
model_dict.update(pretrained_dict)                                               #       
model.load_state_dict(model_dict)

ResNet-50
model = resnet50(num_classes = 2)                                                #      ResNet-50   ,            import
if use_gpu:                                                                      #   GPU
     model = model.cuda()
     model = torch.nn.DataParallel(model)

pretrained_dict = torch.load("lujing/resnet50-6-classes.pth")                    #         ,      6  

pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}  #         
del pretrained_dict["module.fc.weight"]                                          #                
del pretrained_dict["module.fc.bias"]
model_dict.update(pretrained_dict)                                               #       
model.load_state_dict(model_dict)

XceptionNet
model = xception(num_classes = 2)                                                #      ResNet-50   ,             import
if use_gpu:                                                                      #   GPU
     model = model.cuda()
     model = torch.nn.DataParallel(model)

pretrained_dict = torch.load("lujing/xception-6-classes.pth")                    #         ,      6  

pretrained_dict = torch.load(args.resume)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}  #         
del pretrained_dict["module.fc.weight"]                                          #                
del pretrained_dict["module.fc.bias"]
model_dict.update(pretrained_dict)                                               #       
model.load_state_dict(model_dict)

まとめ
VGGロード時にロードするモデルが現在のモデルパラメータ構造と一致しなくても、
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 

リネットやXceptionのようにもう一言追加する必要がなく、ロードを実現できます.
del pretrained_dict["module.fc.weight"]
del pretrained_dict["module.fc.bias"]

具体的な原因は私もはっきりしていないので、大神に教えてもらいたいのですが、このような方法でロードするのは親測で利用できます.