PyTorch で線形回帰を手書き実装する

1092 ワード

原文のリンク:http://www.cnblogs.com/LiuXinyu12378/p/11374748.html
PyTorch で線形回帰を手書き実装する
 
import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Step size for plain (full-batch) gradient descent.
LEARN_RATE = 0.1


def make_data(n_samples=500, true_w=0.8, true_b=3.0):
    """Build a noiseless synthetic dataset following ``y = true_w * x + true_b``.

    Args:
        n_samples: number of points to draw.
        true_w: slope of the target line.
        true_b: intercept of the target line.

    Returns:
        ``(x, y_true)``, both of shape ``(n_samples, 1)``; ``x`` is drawn
        from a standard normal distribution via ``torch.randn``.
    """
    x = torch.randn([n_samples, 1])
    y_true = x * true_w + true_b
    return x, y_true


def train_step(x, y_true, w, b, lr=LEARN_RATE):
    """Run one full-batch gradient-descent step on the mean squared error.

    The model is ``y_predict = x * w + b``. ``w`` and ``b`` must be leaf
    tensors created with ``requires_grad=True``; they are updated in place.

    Args:
        x: input tensor.
        y_true: target tensor, same shape as ``x * w + b``.
        w: slope parameter.
        b: intercept parameter.
        lr: learning rate.

    Returns:
        ``(y_predict, loss)``, both computed with the parameters as they
        were *before* this step's update.
    """
    # backward() accumulates into .grad, so clear it before computing new ones.
    for param in (w, b):
        if param.grad is not None:
            param.grad.zero_()

    y_predict = x * w + b
    loss = (y_predict - y_true).pow(2).mean()
    loss.backward()

    # Apply the update outside autograd so the step itself is not tracked.
    with torch.no_grad():
        w -= lr * w.grad
        b -= lr * b.grad
    return y_predict, loss


def main(n_iters=50, log_every=10):
    """Fit the line by gradient descent and animate the fit with matplotlib.

    Args:
        n_iters: number of gradient-descent steps.
        log_every: print progress every ``log_every`` iterations (the final
            iteration is always printed).
    """
    # 1. Prepare the data.
    x, y_true = make_data()

    # 2. Model parameters for y_predict = x * w + b.
    w = torch.rand([], requires_grad=True)
    b = torch.tensor(0., requires_grad=True)

    plt.figure()
    # Interactive mode so the figure redraws while training runs.
    plt.ion()
    for i in range(n_iters):
        plt.cla()
        # cla() resets axis state (including the grid) to rcParams defaults,
        # so the grid must be re-enabled after every clear.
        plt.grid(True)

        # 3-4. Compute the loss, backpropagate, and update w and b.
        y_predict, loss = train_step(x, y_true, w, b)

        plt.scatter(x.numpy(), y_true.numpy())
        plt.plot(x.numpy(), y_predict.detach().numpy(), color="g")
        plt.pause(0.1)

        if i % log_every == 0 or i == n_iters - 1:
            print(f"iter {i}, loss {loss.item():.6f}, "
                  f"w={w.item():.4f}, b={b.item():.4f}")

    # Leave interactive mode and keep the final figure open.
    plt.ioff()
    plt.show()


if __name__ == "__main__":
    main()
  
転載元:https://www.cnblogs.com/LiuXinyu12378/p/11374748.html