# ミニバッチ勾配降下法（Python） — Mini-batch gradient descent in Python (1588 ワード / 1588 words)
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 13 20:49:03 2018
@author:
"""
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
# Synthetic training set: y = 2*x + 5 plus standard-normal noise.
# Seed FIRST so that both the noise drawn below and the initial weights are
# reproducible (the seed used to be set only after the noise had been drawn,
# so every run trained on different data).
np.random.seed(0)
x = np.arange(0., 10., 0.2)
m = len(x)  # number of samples (50 for this range/step)
print(m)
x0 = np.full(m, 1.0)  # constant bias feature
input_data = np.vstack([x0, x]).T  # design matrix: one row [1, x_i] per sample
target_data = 2 * input_data[:, 1] + 5 * input_data[:, 0] + np.random.randn(m)

# Hyper-parameters and training state for the mini-batch gradient-descent loop.
loop_max = 100000  # upper bound on the number of training epochs
epsilon = 1e-3  # convergence tolerance
theta = np.random.randn(2)  # initial weights; theta[0] is the intercept, theta[1] the slope
alpha = 0.001  # learning rate (step size)
diff = 0.  # placeholder for the prediction residual
error = np.zeros(2)  # weights seen at the previous convergence check
count = 0  # epochs run so far
finish = 0  # set to 1 once training has converged
minibatch_size = 5  # samples per mini-batch
# NOTE(review): only the header "while count" of this loop survived in the
# listing (a SyntaxError, and theta was never trained). The body below is a
# reconstruction: plain mini-batch gradient descent on the squared-error loss
# 0.5 * mean((input_data @ theta - target_data) ** 2), driven by the state
# variables initialised above.
while count < loop_max:
    count += 1
    # One epoch: sweep the samples in order, one mini-batch at a time.
    for start in range(0, m, minibatch_size):
        batch_x = input_data[start:start + minibatch_size]
        batch_y = target_data[start:start + minibatch_size]
        diff = batch_x.dot(theta) - batch_y  # per-sample prediction residual
        # Divide by the actual batch length so a short final batch (m not a
        # multiple of minibatch_size) still yields a true mean gradient.
        theta = theta - alpha * batch_x.T.dot(diff) / len(batch_y)
    # Converged once a whole epoch moves the weights by less than epsilon.
    if np.linalg.norm(theta - error) < epsilon:
        finish = 1
        break
    error = theta  # safe alias: theta is always rebound, never mutated in place
print('loop count = %d' % count, '\tw:', theta)
# Reference answer: SciPy's closed-form least-squares line through the same
# samples; compare its intercept/slope with theta[0]/theta[1].
(slope, intercept,
 r_value, p_value, slope_std_error) = stats.linregress(x, target_data)
print('intercept = %s slope = %s' % (intercept, slope))

# Noisy samples as green stars, the line held in theta drawn in red.
plt.plot(x, target_data, 'g*')
plt.plot(x, theta[0] + theta[1] * x, 'r')
plt.show()