PyTorchプリトレーニングモデルの読み込み例(pretraned)


プリトレーニングモデルを使用するコードは以下の通りです。

#        
 resNet50 = models.resnet50(pretrained=True)
 ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)

 #     
 pretrained_dict = resNet50.state_dict()
 model_dict = ResNet50.state_dict()

 #  pretained_dict    model_dict     
 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

 #      model_dict
 model_dict.update(pretrained_dict)

 #        state_dict
 ResNet50.load_state_dict(model_dict)
以上のPyTorchでプリトレーニングモデルをローディングした事例(pretraned)は、小編集が皆さんに共有している内容です。参考にしていただければと思います。どうぞよろしくお願いします。