pytorch finetuneモデル

2408 ワード

pytorch finetuneモデル


本文では,pytorch上で従来訓練されていたモデルパラメータをどのように読み取るか,モデルの名前が変更された場合にモデルの一部パラメータをどのように読み取るかなどについて述べる.---------作者:jiangwenj 02【転載は明記してください】

pytorchモデルの格納と読み取り


モデルの保存プロセスには、モデルとパラメータを格納するものと、モデルパラメータを個別に格納するものがあります.

モデルパラメータを個別に保存


保存時に使用:
torch.save(the_model.state_dict(), PATH)

読み込み時:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

モデルとパラメータの保存


ストレージ:
torch.save(the_model, PATH)

読み込み:
the_model = torch.load(PATH)

モデルのパラメータ


fine-tuneのプロセスは、既存のモデルのパラメータを読み取ることですが、モデルの処理するデータセットが異なるため、最後のレイヤclassの総数が異なるため、モデルの最後のレイヤを修正する必要があります.このように、モデルが読み取るパラメータは、大きなデータセットでダウンロードを訓練したモデルパラメータとは形式的に異なります.関数読み取りパラメータを自分で書く必要があります.

pytorchモデルパラメータの形式


モデルのパラメータは辞書形式で格納されます.
model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

すべてのキー値が表示されます.モデルのパラメータを変更したい場合は、対応するキー値に値を割り当てるだけです.
model_dict[k] = new_value

最終更新モデルのパラメータ
the_model.load_state_dict(model_dict)

モデルのkey値とビッグデータセットで訓練したkey値が同じである場合


次のアルゴリズムでモデルを読み込むことができます.
model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

モデルのkey値とビッグデータセットで訓練したkey値が異なる場合、順序は同じです。

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

モデルのkey値とビッグデータセットで訓練したkey値が異なる場合、順序も異なります。


自分で対応関係を探して、1つのkeyは1つのkeyの賦値に対応します