深さの目PyTorchキャンプ第4期(13):hook関数


計算図の実行中に中間変数のデータは保持されないことを知っている.中間変数を保存するには、tnesor.retain_grad = Trueを使用しますが、これによりデータが永続的に保存され、記憶空間が消費されます.すなわち、中間変数の値を1回しか使用できないかもしれませんが、永続的に保存されます.この問題を解決するためにPyTorchはhook関数を導入した.hook関数はテンソルのhook関数とニューラルネットワーク層のhook関数の2種類に分けられる.
1.テンソルのhook関数
  • Tensor.register_hook(hook)

  • このhook関数は,逆伝搬を計算するたびに呼び出される.tensorのhook関数には次のような性質があります.
    hook(grad) -> Tensor or None
    

    すなわちhook関数が値を返さなければ、テンソルの導関数は変化しない.戻り値がある場合、テンソル元の導関数は戻り値で上書きされます.テンソルのhook関数の使い方は以下の通りです.
  • 計算図を定義する.
  • はhook関数を定義する.
  • 登録hook関数;
  • 逆伝播;
  • hook関数を削除します.

  • PyTorchのシナモン4:torchを使いますAutographで導関数を計算する例:
    w = torch.tensor([1.], requires_grad=True) #      
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    
    a_grad = list()
    
    def grad_hook(grad): #    hook   
        a_grad.append(grad)
        return grad * 3
    
    handle = w.register_hook(grad_hook) #    hook   
    
    y.backward() #     
    
    #     
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    print("w.grad: ", w.grad)
    handle.remove() #    hook   
    

    実行後、PyTorchは、非リーフノードの導関数を取得しようとすると警告した.結果は次のとおりです.
    gradient: tensor([15.]) tensor([2.]) None None None
    a_grad[0]:  tensor([5.])
    w.grad:  tensor([15.])
    

    本来aの導関数は解放されたが,hook関数によりaの導関数を保存した.戻り値があるため、wの導関数が覆われている.
    2.ニューラルネットワークのhook関数
    ニューラルネットワークのhook関数は,順方向伝搬のhook関数と逆方向伝搬のhook関数に分けられ,それぞれregister_forward_hook(hook)register_backward_hook(hook)であり,使用方法はテンソルのhook関数と大きく異なる.例を見てみましょう
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc = nn.Linear(5, 1)
    
        def forward(self, x):
            x = self.fc(x)
            return x
    
    
    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)
    
    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))
    
    
    #      
    net = Net()
    net.fc.weight.detach().fill_(2)
    net.fc.bias.data.detach().zero_()
    
    #   hook
    fmap_block = list()
    input_block = list()
    net.fc.register_forward_hook(forward_hook)
    net.fc.register_backward_hook(backward_hook)
    
    # inference
    x = torch.ones(5)
    output = net(x)
    
    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()
    

    ここでは,l 1損失関数を用いて得られた全接続層の逆伝搬導関数は,全接続ネットワークの重みが2の単層全接続ネットワークを定義した.
    backward hook input:(tensor([1.]), tensor([1.]))
    backward hook output:(tensor([1.]),)