[Pytorch]モデルの学習パラメータを手動で代入したい


記事を読んで何か指摘があれば、遠慮なくどうぞ。
いいねもしてもらえると励みになります。

モデルの学習パラメータを手動で代入したいとき

Pytorchを使っている時に、パラメータの値に制約をつけたい時に以下の方法で実現可能です。

具体例

  • 具体的には、$w > 0$ のように、パラメータの行列またはベクトルの各要素に値の制約をつけたい場合に以下のように実現できます。
  • 以下のコードでは、モデル内のパラメータの全ての最小値がmin=1e-4になるように設定しています。
  • forで行なっているのは、keys()を使って、パラメータの全てに順番にアクセスして、clampという関数を適用しています。(numpyで言う所のnp.clipです。)

コード

state_dict = model.state_dict()#モデル内のパラメータの呼び出し
for k in state_dict.keys():
    state_dict[k] = torch.clamp(state_dict[k], min=1e-4)
model.load_state_dict(state_dict)

ポイント

  1. torch.clampの使用
  2. model.load_state_dictの使用

解説

  • 例えば、値の最小値が0という制約をつけたい時、torch.clampという関数を使って、値の制限をすることができます。torch.clamp(input,min=0,max=10)などで値の範囲を制限することができます。
  • 直接、モデルのパラメータに値を代入することは難しいので、model.load_state_dictという方法があるみたいです。