ミニバッチ勾配降下アルゴリズム（Python）

1588 ワード

# -*- coding: utf-8 -*-
"""
Created on Tue Mar 13 20:49:03 2018

@author: 
"""

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

# Synthetic regression data: y = 2*x + 5 + N(0, 1) noise, sampled on the
# grid x = 0.0, 0.2, 0.4, ... (stop value 10.0 excluded).
x = np.arange(0., 10., 0.2)
m = len(x)  # number of samples
print(m)
# Design matrix with a leading column of ones, so each row is [1, x[i]]:
# theta[0] then plays the role of the intercept and theta[1] of the slope.
x0 = np.full(m, 1.0)
input_data = np.vstack([x0, x]).T
# Draw the noise from a dedicated fixed-seed generator so the data set is the
# same on every run and does not depend on (or disturb) NumPy's global RNG.
noise = np.random.RandomState(0).randn(m)
target_data = 2 * input_data[:, 1] + 5 * input_data[:, 0] + noise

loop_max = 100000           # maximum number of epochs (full passes over the data)
epsilon = 1e-3              # convergence tolerance on theta's change over one epoch

np.random.seed(0)           # fixed seed so the random initial theta is reproducible
theta = np.random.randn(2)  # initial parameters: [intercept, slope]

alpha = 0.001               # learning rate: too large oscillates/diverges, too small is slow
count = 0                   # number of epochs run (set by the training call below)
finish = 0                  # becomes 1 once the convergence test succeeds
minibatch_size = 5          # number of samples averaged per parameter update


def minibatch_gradient_descent(features, targets, theta_init, learning_rate,
                               batch_size, max_epochs, tol):
    """Fit a linear model by mini-batch gradient descent on the squared error.

    Each epoch walks the rows of ``features`` in order, in consecutive batches
    of ``batch_size`` rows (the last batch may be shorter).  After each batch,
    theta takes one step against the batch-averaged gradient of
    ``0.5 * (features . theta - targets) ** 2``.  Training stops as soon as
    the Euclidean norm of theta's change over one epoch falls below ``tol``,
    or after ``max_epochs`` epochs.

    Note that a small per-epoch change does not by itself guarantee theta is
    close to the least-squares optimum when progress along some direction is
    slow; a tighter ``tol`` trades run time for accuracy.

    Args:
        features: 2-D array with one row per sample (here ``[1, x]``).
        targets: 1-D array holding one target value per row of ``features``.
        theta_init: initial parameter vector, one entry per feature column.
            It is copied, never modified.
        learning_rate: step size applied to each batch gradient.
        batch_size: number of rows per update; must be at least 1.
        max_epochs: maximum number of passes over the data.
        tol: convergence tolerance on theta's per-epoch change.

    Returns:
        ``(theta, epochs, converged)``: the final parameters, the number of
        epochs run, and whether the tolerance was reached.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    features = np.asarray(features, dtype=float)
    targets = np.asarray(targets, dtype=float)
    theta = np.array(theta_init, dtype=float)  # private copy of the start point
    n_samples = len(targets)

    epochs = 0
    while epochs < max_epochs:
        epochs += 1
        previous = theta  # safe alias: theta is rebound below, never mutated
        for start in range(0, n_samples, batch_size):
            # Slicing (instead of indexing a fixed batch_size rows) keeps a
            # short final batch in range when n_samples % batch_size != 0.
            batch_x = features[start:start + batch_size]
            batch_y = targets[start:start + batch_size]
            residual = np.dot(batch_x, theta) - batch_y
            gradient = np.dot(batch_x.T, residual) / len(batch_y)
            theta = theta - learning_rate * gradient
        if np.linalg.norm(theta - previous) < tol:
            return theta, epochs, True
    return theta, epochs, False


theta, count, converged = minibatch_gradient_descent(
    input_data, target_data, theta, alpha, minibatch_size, loop_max, epsilon)
finish = 1 if converged else 0
print('loop count = %d, converged = %s, theta = %s' % (count, converged, theta))

# Closed-form least-squares fit from SciPy, printed as a reference against
# which the gradient-descent estimate held in theta can be compared.
reference_fit = stats.linregress(x, target_data)
slope, intercept, r_value, p_value, slope_std_error = reference_fit
print('intercept = %s slope = %s' % (intercept, slope))

# Line learned by gradient descent: theta[0] is the intercept, theta[1] the slope.
gd_line = theta[0] + theta[1] * x

# Noisy samples as green stars, gradient-descent line in red.
plt.plot(x, target_data, 'g*')
plt.plot(x, gd_line, 'r')
plt.show()