pytorchステータス辞書:state_dict

5066 ワード

pytorchのstate_dictは簡単なpythonの辞書オブジェクトであり、各層とその対応するパラメータとマッピング関係を確立する.(モデルの各層のweightsやバイアスなど)
(ボリューム層、線形層などのモデルのstate_dictには、パラメータが訓練可能なlayerのみが保存されることに注意してください.
オプティマイザオブジェクトOptimizerにもstate_がありますdictは、オプティマイザの状態およびlr、momentum、weight_decayなどの使用されるスーパーパラメータを含む
 
コメント:
1) state_dictはmodelまたはoptimizerを定義後にpytorchが自動的に生成し、直接呼び出すことができる.よく使う保存state_dictのフォーマットは「.pt」または'.pth'のファイル、すなわち次のコマンドのPATH="./***.pt"
torch.save(model.state_dict(), PATH)

2) load_state_dictもmodelやoptimizer以降pytorchが自動的に備える関数で、直接呼び出すことができます
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval()の重要性は,2)で最後にmodelを用いた.eval()は、このコマンドを実行する後にのみ、「dropoutレイヤ」および「batch normalizationレイヤ」がevalutionモードに入るためである.「訓練(training)モード」と「評価(evalution)モード」の下で、この2つの層は異なる表現形式を持っている.
-------------------------------------------------------------------------------------------------------------------------------
モダリティ辞書(state_dict)の保存(model )
1.1)学習したパラメータのみを保存し、以下のコマンドを使用する    torch.save(model.state_dict(), PATH) 1.2) model.state_dict,     model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
    :model.load_state_dict , ----------- 2.1) model ,     torch.save(model,PATH) 2.2) model , :
          # Model class must be defined somewhere     model = torch.load(PATH)     model.eval()
--------------------------------------------------------------------------------------------------------------------------------------
state_dictはpythonの辞書形式で、辞書の形式で格納され、辞書の形式でロードされ、keyが一致する項目のみがロードされます.
----------------------------------------------------------------------------------------------------------------------
あるレベルのトレーニングのパラメータ(あるレベルのstate)のみをロードする方法
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

--------------------------------------------------------------------------------------------
モデルパラメータをロードした後、あるレベルのパラメータの「トレーニングが必要かどうか」(param.requires_grad)を設定する方法
for param in list(model.pretrained.parameters()):
    param.requires_grad = False

注意:requires_gradの操作対象はtensor.
疑問:ある層に直接requiresを使うことができますか?gradは?例:model.conv1.requires_grad=False
回答:テストを受けて、できません.model.conv 1にはrequiresがありませんgrad属性
 
---------------------------------------------------------------------------------------------
すべてのテストコード:
#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim



# define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1,16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# initial model
model = TheModelClass()

#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor,'\t',model.state_dict()[param_tensor].size())

print("
optimizer's state_dict") for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) print("
print particular param") print('
',model.conv1.weight.size()) print('
',model.conv1.weight) print("------------------------------------") torch.save(model.state_dict(),'./model_state_dict.pt') # model_2 = TheModelClass() # model_2.load_state_dict(torch.load('./model_state_dict')) # model.eval() # print('
',model_2.conv1.weight) # print((model_2.conv1.weight == model.conv1.weight).size()) ## conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight'] print(conv1_weight_state==model.conv1.weight) model_2 = TheModelClass() model_2.load_state_dict(torch.load('./model_state_dict.pt')) model_2.conv1.requires_grad=False print(model_2.conv1.requires_grad) print(model_2.conv1.bias.requires_grad)