pytorch finetuneモデル
2408 ワード
pytorch finetuneモデル
本文では,pytorch上で従来訓練されていたモデルパラメータをどのように読み取るか,モデルの名前が変更された場合にモデルの一部パラメータをどのように読み取るかなどについて述べる.---------作者:jiangwenj 02【転載は明記してください】
pytorchモデルの格納と読み取り
モデルの保存プロセスには、モデルとパラメータを格納するものと、モデルパラメータを個別に格納するものがあります.
モデルパラメータを個別に保存
保存時に使用:torch.save(the_model.state_dict(), PATH)
読み込み時:the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
モデルとパラメータの保存
ストレージ:torch.save(the_model, PATH)
読み込み:the_model = torch.load(PATH)
モデルのパラメータ
fine-tuneのプロセスは、既存のモデルのパラメータを読み取ることですが、モデルの処理するデータセットが異なるため、最後のレイヤclassの総数が異なるため、モデルの最後のレイヤを修正する必要があります.このように、モデルが読み取るパラメータは、大きなデータセットでダウンロードを訓練したモデルパラメータとは形式的に異なります.関数読み取りパラメータを自分で書く必要があります.
pytorchモデルパラメータの形式
モデルのパラメータは辞書形式で格納されます.model_dict = the_model.state_dict(),
for k,v in model_dict.items():
print(k)
すべてのキー値が表示されます.モデルのパラメータを変更したい場合は、対応するキー値に値を割り当てるだけです.model_dict[k] = new_value
最終更新モデルのパラメータthe_model.load_state_dict(model_dict)
モデルのkey値とビッグデータセットで訓練したkey値が同じである場合
次のアルゴリズムでモデルを読み込むことができます.model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
モデルのkey値とビッグデータセットで訓練したkey値が異なる場合、順序は同じです。
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)
モデルのkey値とビッグデータセットで訓練したkey値が異なる場合、順序も異なります。
自分で対応関係を探して、1つのkeyは1つのkeyの賦値に対応します
torch.save(the_model.state_dict(), PATH)
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
torch.save(the_model, PATH)
the_model = torch.load(PATH)
model_dict = the_model.state_dict(),
for k,v in model_dict.items():
print(k)
model_dict[k] = new_value
the_model.load_state_dict(model_dict)
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)