[Pytorch]モデルの学習パラメータを手動で代入したい
記事を読んで何か指摘があれば、遠慮なくどうぞ。
いいねもしてもらえると励みになります。
モデルの学習パラメータを手動で代入したいとき
Pytorchを使っている時に、パラメータの値に制約をつけたい時に以下の方法で実現可能です。
具体例
- 具体的には、$w > 0$ のように、パラメータの行列またはベクトルの各要素に値の制約をつけたい場合に以下のように実現できます。
- 以下のコードでは、モデル内のパラメータの全ての最小値が
min=1e-4
になるように設定しています。 -
for
で行なっているのは、keys()
を使って、パラメータの全てに順番にアクセスして、clamp
という関数を適用しています。(numpy
で言う所のnp.clip
です。)
コード
state_dict = model.state_dict()#モデル内のパラメータの呼び出し
for k in state_dict.keys():
state_dict[k] = torch.clamp(state_dict[k], min=1e-4)
model.load_state_dict(state_dict)
ポイント
-
torch.clamp
の使用 -
model.load_state_dict
の使用
解説
- 例えば、値の最小値が0という制約をつけたい時、
torch.clamp
という関数を使って、値の制限をすることができます。torch.clamp(input,min=0,max=10)
などで値の範囲を制限することができます。 - 直接、モデルのパラメータに値を代入することは難しいので、
model.load_state_dict
という方法があるみたいです。
Author And Source
この問題について([Pytorch]モデルの学習パラメータを手動で代入したい), 我々は、より多くの情報をここで見つけました https://qiita.com/marusta/items/1ac26af5b39d757cbb07著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .