PyTorchの自動微分を用いて関数の接線を描画してみた
この記事では、fastbookのサンプルコード04_mnist_basics.ipynbをベースに自分の実験を入れながら、次のように関数の接線を表示するコードを書いてみました。
モジュールのインポート
plot_function()
やtensor()
など、一部fastbookのツール関数を使っている関係で、サンプルコードに従って、次のようにfast.ai提供のモジュールをインポートしています。(plot_function()
を自作すれば、普通にnumpy
matplotlib
torch
でOKです)
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
from fastai.vision.all import *
from fastbook import *
接線の方程式を求める
まず、接線の方程式を求める必要があります。関数を$f(x)$とし、$x=x_0$での微分値を$a$とし、$y_0 = f(x_0)$として、接線の方程式を$y=ax+b$とすれば、その接線は$(x_0,y_0)$を通るので
\begin{align}
y_0 &= a x_0 + b\\
\therefore\ y &= y_0 + a(x - x_0) = a x + y_0 - a x_0\\
b &= y_0 - a x_0
\end{align}
となります。次のコードは、この結果に基づいて接線の方程式を求める関数です。
def tangent_line(f, x0:Tensor):
y0 = f(x0)
y0.backward()
with torch.no_grad():
a = x0.grad.clone()
b = y0 - a*x0
x0.grad.zero_()
return lambda x : a*x + b
x0.grad.zero_()
は、テンソル変数の微分値をゼロクリアするためのコードで、このtangent_line
関数を単独にコールし引数$x_0$が使いまわさいれていると、関数tangent_line
をコールする都度微分値(grad
)が累積されるためです。また、a = x0.grad.clone()
とするのは、そうしないとa
とx_0.grad
が同じメモリをシェアするらしく、x0.grad.zero_()
をコールした時点でa
が参照する値がクリアされるためです。
接線を描画する
次のコードは、与えられた関数f(x)
の指定されたx座標x
での接線を描画する関数です。
def plot_tangent_line(f, x, x_range=1.):
xt = tensor(x).requires_grad_()
tf = tangent_line(f, xt)
with torch.no_grad():
x_r = np.array( [x - x_range, x + x_range] )
plot_function(f, 'x', 'f(x)', min=x_r[0], max=x_r[1])
plt.scatter(xt, f(xt), color='red')
plt.plot(x_r,tf(x_r))
plt.show()
tf
は接線の方程式で、x_r
はx座標の描画範囲[x-x_range, x+x_range]
です。また、tensor
とplot_function
はfast.aiが提供する関数です。そして、接線はplt.plot(x_r,tf(x_r))
で描画しています。具体的には、(x_r[0],tf(x_r[0]))
と(x_r[1],tf(x_r[1]))
の2点を結ぶ直線になります。
たとえば、$f(x)=0.78x^2+sin(x)$の$x=-0.75$での接線を描画するには次のようにします。
def f(x): return 0.78*x**2 + torch.sin(x)
plot_tangent_line(f, -.75)
関連記事
fast.aiで作るハンドメイドの超簡単な手書き数字判別器
fastbookのサンプルコード(02_production.ipynb)をGoogle Colabで試してみた Part 1
PyTorchの自動微分を用いて2変数関数の接平面を描画してみた
fast.ai公式ページ
Practical Deep Learning for Coders
Lesson 3 - Deep Learning for Coders (2020)
Lesson 4 - Deep Learning for Coders (2020)
Author And Source
この問題について(PyTorchの自動微分を用いて関数の接線を描画してみた), 我々は、より多くの情報をここで見つけました https://qiita.com/kumagorou/items/6081d9cf15024c839103著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .