深さの目PyTorchキャンプ第4期(13):hook関数
計算図の実行中に中間変数のデータは保持されないことを知っている.中間変数を保存するには、
1.テンソルのhook関数
このhook関数は,逆伝搬を計算するたびに呼び出される.tensorのhook関数には次のような性質があります.
すなわちhook関数が値を返さなければ、テンソルの導関数は変化しない.戻り値がある場合、テンソル元の導関数は戻り値で上書きされます.テンソルのhook関数の使い方は以下の通りです.計算図を定義する. はhook関数を定義する. 登録hook関数; 逆伝播; hook関数を削除します.
PyTorchのシナモン4:torchを使いますAutographで導関数を計算する例:
実行後、PyTorchは、非リーフノードの導関数を取得しようとすると警告した.結果は次のとおりです.
本来
2.ニューラルネットワークのhook関数
ニューラルネットワークのhook関数は,順方向伝搬のhook関数と逆方向伝搬のhook関数に分けられ,それぞれ
ここでは,l 1損失関数を用いて得られた全接続層の逆伝搬導関数は,全接続ネットワークの重みが2の単層全接続ネットワークを定義した.
tnesor.retain_grad = True
を使用しますが、これによりデータが永続的に保存され、記憶空間が消費されます.すなわち、中間変数の値を1回しか使用できないかもしれませんが、永続的に保存されます.この問題を解決するためにPyTorchはhook関数を導入した.hook関数はテンソルのhook関数とニューラルネットワーク層のhook関数の2種類に分けられる.1.テンソルのhook関数
Tensor.register_hook(hook)
このhook関数は,逆伝搬を計算するたびに呼び出される.tensorのhook関数には次のような性質があります.
hook(grad) -> Tensor or None
すなわちhook関数が値を返さなければ、テンソルの導関数は変化しない.戻り値がある場合、テンソル元の導関数は戻り値で上書きされます.テンソルのhook関数の使い方は以下の通りです.
PyTorchのシナモン4:torchを使いますAutographで導関数を計算する例:
w = torch.tensor([1.], requires_grad=True) #
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad): # hook
a_grad.append(grad)
return grad * 3
handle = w.register_hook(grad_hook) # hook
y.backward() #
#
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
print("w.grad: ", w.grad)
handle.remove() # hook
実行後、PyTorchは、非リーフノードの導関数を取得しようとすると警告した.結果は次のとおりです.
gradient: tensor([15.]) tensor([2.]) None None None
a_grad[0]: tensor([5.])
w.grad: tensor([15.])
本来
a
の導関数は解放されたが,hook関数によりa
の導関数を保存した.戻り値があるため、w
の導関数が覆われている.2.ニューラルネットワークのhook関数
ニューラルネットワークのhook関数は,順方向伝搬のhook関数と逆方向伝搬のhook関数に分けられ,それぞれ
register_forward_hook(hook)
とregister_backward_hook(hook)
であり,使用方法はテンソルのhook関数と大きく異なる.例を見てみましょうclass Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(5, 1)
def forward(self, x):
x = self.fc(x)
return x
def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
def backward_hook(module, grad_input, grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
#
net = Net()
net.fc.weight.detach().fill_(2)
net.fc.bias.data.detach().zero_()
# hook
fmap_block = list()
input_block = list()
net.fc.register_forward_hook(forward_hook)
net.fc.register_backward_hook(backward_hook)
# inference
x = torch.ones(5)
output = net(x)
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
ここでは,l 1損失関数を用いて得られた全接続層の逆伝搬導関数は,全接続ネットワークの重みが2の単層全接続ネットワークを定義した.
backward hook input:(tensor([1.]), tensor([1.]))
backward hook output:(tensor([1.]),)