PyTorch - Autograd


g=g(x)=2x3g = g(x) = 2x^3g=g(x)=2x3
h=h(x)=112xh = h(x) =\frac{1}{12}xh=h(x)=121​x
上記の2つの関数があり、zzzは以下のように定義されています.
z=h(g(x))z=h(g(x))z=h(g(x))
zzzをxxxに微分し、Chainルールで以下のようにします.
∂z∂x=∂z∂g⋅∂g∂x\frac{\partial z}{\partial x}={\color{red}\frac{\partial z}{\partial g}}\cdot{\color{blue}\frac{\partial g}{\partial x}}∂x∂z​=∂g∂z​⋅∂x∂g​
∂z∂g=∂∂g⋅h(g)=∂∂g⋅112g=112{\color{red}\frac{\partial z}{\partial g}} =\frac{\partial}{\partial g}\cdot h(g)=\frac{\partial}{\partial g}\cdot\frac{1}{12}g=\frac{1}{12}∂g∂z​=∂g∂​⋅h(g)=∂g∂​⋅121​g=121​
∂g∂x=∂∂x⋅g(x)=∂∂x2x3=6x2{\color{blue}\frac{\partial g}{\partial x}}=\frac{\partial}{\partial x}\cdot g(x)=\frac{\partial}{\partial x}2x^3=6x^2∂x∂g​=∂x∂​⋅g(x)=∂x∂​2x3=6x2
∴∂z∂x=112⋅6x2=12x2\therefore\frac{\partial z}{\partial x}=\frac{1}{12}\cdot6x^2=\frac{1}{2}x^2∴∂x∂z​=121​⋅6x2=21​x2
以下に示すようにPyTorchとして実装する.(関数の動作が異なる)
# Settings
import torch

# Make input x
x = torch.randn(3,3, requires_grad=True)
print(x)

tensor([[ 0.4490, -0.7287, -1.2367],
        [ 1.0262,  0.3886, -1.2934],
        [ 0.2356, -1.5048,  0.1285]], requires_grad=True)
まずxxxに任意の値を持つTensorを割り当てます.requires_grad=Trueを介してTensorのGradientを取得する準備をする.
[!] requires_grad=Trueで生成されたTensorを別の操作を行い、xxx変数に割り当て、x.gradでGradientを読み込むとエラーが発生します.
[+]Tensorを印刷し、最後にrequires_grad=Trueを表示できます.

# Make function g
g = 2*(x**3)
print(g)
print(g.grad_fn)

tensor([[ 1.8108e-01, -7.7389e-01, -3.7829e+00],
        [ 2.1612e+00,  1.1733e-01, -4.3271e+00],
        [ 2.6150e-02, -6.8144e+00,  4.2403e-03]], grad_fn=<MulBackward0>)
<MulBackward0 object at 0x000001C028EE13D0>
ggg関数を生成します.
[+]ggg関数演算を行うTensorにgrad_fnを加えると,複数回演算が行われたことを示す.MulBackward0 object
z = g/12
z.backward()

RuntimeError: grad can be implicitly created only for scalar outputs
[!] 最終的にGradientを取得する値はScalar出力でなければなりません.損失値!
z = g.sum()/12
print(z)

tensor(-1.1007, grad_fn=<DivBackward0>)
z.backward()
print(x.grad)
print( (x**2)/2 )

tensor([[0.1008, 0.2655, 0.7647],
        [0.5265, 0.0755, 0.8364],
        [0.0278, 1.1321, 0.0083]])
        
tensor([[0.1008, 0.2655, 0.7647],
        [0.5265, 0.0755, 0.8364],
        [0.0278, 1.1321, 0.0083]], grad_fn=<DivBackward0>)
g(x)=2x3g(x)=2x^3g(x)=2x3
                h(x)=112∑x\;\;\;\;\;\;\;\;h(x)=\frac{1}{12}\sum xh(x)=121​∑x
z=h(x)z = h(x)z=h(x)
∴∂z∂x=12x2\therefore\frac{\partial z}{\partial x}=\frac{1}{2}x^2∴∂x∂z​=21​x2
上記のコードおよび式に示すように、z.backward()を実行し、x.gradによって勾配を容易に求めることができる.(zzz対xxxの偏微分)