PyTorchで勾配降下法


PyTorchで勾配降下法をするコードを書いてみました。

最適化したい関数（Rosenbrock 関数の変形版です。標準形では第 1 項の係数が 100 ですが、ここでは 10 を使っています）

def rosenbrock(x0, x1, *, a=1, b=10):
    """Return the Rosenbrock-type function b*(x1 - x0**2)**2 + (x0 - a)**2.

    Works element-wise on Python numbers, NumPy arrays and torch tensors
    (only ``-``, ``*`` and ``**`` are used). The global minimum is 0 at
    (x0, x1) = (a, a**2).

    Args:
        x0: First coordinate.
        x1: Second coordinate.
        a: Location parameter of the minimum (keyword-only, default 1).
        b: Weight of the curved-valley term (keyword-only). The textbook
            Rosenbrock function uses b=100; this file defaults to 10, which
            keeps the original behaviour of this function.

    Returns:
        The function value, with the type/shape produced by broadcasting
        ``x0`` and ``x1``.
    """
    return b * (x1 - x0 ** 2) ** 2 + (x0 - a) ** 2

関数を可視化する

import numpy as np

# Sampling step and the plotting window [x_min, x_max) x [y_min, y_max).
h = 0.01
x_min, x_max = -2, 2
y_min, y_max = -3, 5

# 1-D sample positions along each axis; np.arange excludes the upper bound.
X = np.arange(x_min, x_max, h)
Y = np.arange(y_min, y_max, h)

# 2-D coordinate grids (default "xy" indexing): xx[i, j] == X[j] and
# yy[i, j] == Y[i], so rows follow Y and columns follow X.
xx, yy = np.meshgrid(X, Y)

最小値はこのへんですね。グリッド上の全点で関数値を計算し、最小になる点を探して確認します（理論上の最小値は (x0, x1) = (1, 1) で 0 です）。

# Evaluate the function on the whole grid; rows follow Y, columns follow X.
matrix = rosenbrock(xx, yy)

# Brute-force search for the smallest grid value. np.argmin returns the first
# minimum in row-major order, i.e. the same cell a nested row/column scan with
# a strict comparison would keep, but without a Python-level loop over every
# cell.
row, col = np.unravel_index(np.argmin(matrix), matrix.shape)
minimum = matrix[row, col]
min_y = Y[row]
min_x = X[col]

print(min_x, min_y, minimum)
1.0000000000000027 0.9999999999999147 8.208018832734106e-26

黒いドットが、グリッド探索で見つけた最小値の位置です。関数値の幅が大きいため、等高線は関数値の平方根で描いています。

import matplotlib.pyplot as plt
# Filled contours of sqrt(f): the square root compresses the large range of
# function values so the low valley gets more distinguishable bands.
contour = plt.contourf(xx, yy, np.sqrt(rosenbrock(xx, yy)), alpha=0.5)
# Grid-search minimum as a black dot.
plt.scatter(min_x, min_y, c="k")
# Pass the contour set explicitly: pyplot.scatter makes the scatter the
# "current image", so a bare plt.colorbar() would describe the dot instead of
# the contour levels.
plt.colorbar(contour)
plt.grid()
plt.show()

勾配降下法

import numpy as np
import torch

# Starting point of the descent; both coordinates are leaf tensors so
# autograd fills in x0.grad / x1.grad.
x0 = torch.tensor(0.0, requires_grad=True)
x1 = torch.tensor(4.0, requires_grad=True)

lr = 0.001     # step size
iters = 10000  # number of gradient steps

# history[k] is a float32 array [x0, x1] holding the position before step k;
# the final position is appended after the loop so the whole path is kept.
history = []
for _ in range(iters):
    # stack + detach copies the current values, so later in-place updates
    # cannot change entries that were already recorded.
    history.append(torch.stack([x0, x1]).detach().numpy())
    y = rosenbrock(x0, x1)
    y.backward()

    # Plain gradient step. Updating the leaves in place under no_grad keeps
    # the update out of the autograd graph (no need for .data). Gradients are
    # then reset because backward() accumulates into .grad.
    with torch.no_grad():
        x0 -= lr * x0.grad
        x1 -= lr * x1.grad

        x0.grad.zero_()
        x1.grad.zero_()

history.append(torch.stack([x0, x1]).detach().numpy())

結果表示

import matplotlib.pyplot as plt
# Background: filled contours of sqrt(f) (the square root compresses the
# range of values so the valley stays visible).
contour = plt.contourf(xx, yy, np.sqrt(rosenbrock(xx, yy)), alpha=0.5)
# history is a list of [x0, x1] pairs; one array conversion gives an
# (n_points, 2) array whose columns are the two coordinates of the path.
path = np.asarray(history)
plt.scatter(path[:, 0], path[:, 1])
# Grid-search minimum as a black dot.
plt.scatter(min_x, min_y, c="k")
# Pass the contour set explicitly; otherwise the colorbar would attach to the
# most recent scatter (pyplot.scatter sets the current image).
plt.colorbar(contour)
plt.grid()
plt.show()

複数箇所からスタート

import numpy as np
import torch

# Four independent starting points: run k starts at (x0[k], x1[k]).
x0 = torch.tensor([0.0, -0.5, -1.5, 1.5], requires_grad=True)
x1 = torch.tensor([4.0, 4.0, -1.0, -2.0], requires_grad=True)

lr = 0.001     # step size
iters = 10000  # number of gradient steps

# history[k] is a float32 array of shape (2, n_starts): row 0 holds x0 and
# row 1 holds x1 for every run before step k. The final positions are
# appended after the loop.
history = []
for _ in range(iters):
    # stack + detach copies the values, so in-place updates below do not
    # alter recorded entries.
    history.append(torch.stack([x0, x1]).detach().numpy())
    y = rosenbrock(x0, x1)  # one function value per starting point
    # backward() needs a scalar. y[k] depends only on x0[k] and x1[k], so the
    # gradient of sum(y) w.r.t. x0[k] equals d y[k] / d x0[k]: summing runs
    # all descents in parallel without coupling them.
    y.sum().backward()

    # Gradient step on the leaves outside the autograd graph, then reset the
    # accumulated gradients.
    with torch.no_grad():
        x0 -= lr * x0.grad
        x1 -= lr * x1.grad

        x0.grad.zero_()
        x1.grad.zero_()

history.append(torch.stack([x0, x1]).detach().numpy())

結果表示

import matplotlib.pyplot as plt
# Background: filled contours of sqrt(f) (the square root compresses the
# range of values so the valley stays visible).
contour = plt.contourf(xx, yy, np.sqrt(rosenbrock(xx, yy)), alpha=0.5)
# Each history entry holds all x0 values at index 0 and all x1 values at
# index 1; derive the number of runs instead of hard-coding it.
n_starts = len(history[0][0])
for k in range(n_starts):
    plt.scatter([p[0][k] for p in history], [p[1][k] for p in history])
# Grid-search minimum as a black dot.
plt.scatter(min_x, min_y, c="k")
# Pass the contour set explicitly; a bare plt.colorbar() would attach to the
# most recent scatter (pyplot.scatter sets the current image).
plt.colorbar(contour)
plt.grid()
plt.show()