PyTorchプリトレーニングモデルの読み込み例(pretraned)
プリトレーニングモデルを使用するコードは以下の通りです。
#
resNet50 = models.resnet50(pretrained=True)
ResNet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=2)
#
pretrained_dict = resNet50.state_dict()
model_dict = ResNet50.state_dict()
# pretained_dict model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict
model_dict.update(pretrained_dict)
# state_dict
ResNet50.load_state_dict(model_dict)
以上のPyTorchでプリトレーニングモデルをローディングした事例(pretraned)は、小編集が皆さんに共有している内容です。参考にしていただければと思います。どうぞよろしくお願いします。