PyTorchによる線形回帰のスクラッチ実装
1092 語
原文のリンク: http://www.cnblogs.com/LiuXinyu12378/p/11374748.html
PyTorchによる線形回帰のスクラッチ実装
転載元: https://www.cnblogs.com/LiuXinyu12378/p/11374748.html
PyTorchによる線形回帰のスクラッチ実装
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# Fit y = w * x + b to synthetic data with hand-written gradient descent,
# redrawing the current fitted line after every step.
LEARN_RATE = 0.1
EPOCHS = 50
# Report progress every PRINT_EVERY epochs (and always on the final epoch).
PRINT_EVERY = 10

# 1. Prepare data: 500 samples drawn from y = 0.8 * x + 3 (no noise added).
x = torch.randn([500, 1])
y_true = x * 0.8 + 3

# 2. Model parameters for y_predict = x * w + b (both scalar tensors).
w = torch.rand([], requires_grad=True)
b = torch.tensor(0., requires_grad=True)

plt.figure()
plt.grid(True)
# Interactive mode: the figure is redrawn on each plt.pause() call.
plt.ion()
for i in range(EPOCHS):
    plt.cla()

    # backward() accumulates into .grad, so clear the previous step's
    # gradients first (.grad is None before the first backward()).
    for param in (w, b):
        if param.grad is not None:
            param.grad.zero_()

    y_predict = x * w + b

    # 3. Mean squared error, then backpropagate to fill w.grad and b.grad.
    loss = (y_predict - y_true).pow(2).mean()
    loss.backward()

    # 4. Gradient-descent update. no_grad() keeps the update itself out of
    #    the autograd graph.
    with torch.no_grad():
        w -= LEARN_RATE * w.grad
        b -= LEARN_RATE * b.grad

    plt.scatter(x.numpy(), y_true.numpy())
    plt.plot(x.numpy(), y_predict.detach().numpy(), color="g")
    plt.pause(0.1)

    # The original `i % 50 == 0` with 50 epochs only ever printed epoch 0.
    if i % PRINT_EVERY == 0 or i == EPOCHS - 1:
        print("epoch {}, loss {}, w={}, b={}".format(i, loss.item(), w.item(), b.item()))

# Leave interactive mode so plt.show() keeps the final figure open.
plt.ioff()
plt.show()
転載元: https://www.cnblogs.com/LiuXinyu12378/p/11374748.html