from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
# Synthetic two-cluster data set with two features per sample; `y_` holds the
# cluster label of each row of `X`.
X, y_ = make_blobs(n_samples=10000, centers=2, n_features=2)

# Labels as a column so they line up with column-shaped model outputs.
y = y_.reshape(-1, 1)

x_train, x_test, y_train, y_test = train_test_split(X, y)

# Append a constant column of ones to each split so the last weight of a
# linear model acts as the bias term.
x_train = np.column_stack((x_train, np.ones(len(x_train))))
x_test = np.column_stack((x_test, np.ones(len(x_test))))
class Logstic_GD:
    """Binary logistic regression trained with full-batch gradient descent.

    The design matrix is used exactly as passed in; no bias column is added
    here. The decision-boundary snapshots recorded during training read
    ``W[0]``, ``W[1]`` and ``W[-1]``, i.e. they treat the last weight as the
    intercept of a line in the plane of the first two features.
    """

    def __init__(self):
        # Weight column vector of shape (n_features, 1); None until trained.
        self.W = None

    def sigmod(self, z):
        """Return the logistic function ``1 / (1 + exp(-z))`` element-wise."""
        # exp(-log(1 + exp(-z))) is the same function, but logaddexp keeps
        # the intermediate finite for large negative z where exp(-z) would
        # overflow.
        return np.exp(-np.logaddexp(0, -z))

    def Logstic_train(self, x_hyperplane, x_train, y_train, epoch, learning_rate):
        """Fit the weights by gradient descent on the mean cross-entropy.

        Args:
            x_hyperplane: x-coordinates at which the current decision
                boundary is evaluated after every update.
            x_train: design matrix of shape (n_samples, n_features).
            y_train: 0/1 labels, either (n_samples,) or (n_samples, 1).
            epoch: number of gradient steps.
            learning_rate: step size applied to the gradient.

        Returns:
            Tuple ``(W, losses, predictions)``: the fitted weights of shape
            (n_features, 1), the per-epoch loss values, and the per-epoch
            boundary y-values evaluated at ``x_hyperplane``.
        """
        # Force a column vector: a 1-D label array would otherwise broadcast
        # against the (n_samples, 1) predictions into an (n, n) matrix.
        y_train = np.asarray(y_train).reshape(-1, 1)
        n_samples, n_feature = x_train.shape
        self.W = np.zeros((n_feature, 1))
        losses = []
        predictions = []
        # Probabilities are clipped away from 0 and 1 before taking logs so
        # a saturated sigmoid cannot turn the loss into inf/nan.
        eps = np.finfo(float).eps
        for i in range(epoch):
            y_pred = self.sigmod(np.dot(x_train, self.W))
            self.dw = np.dot(x_train.T, y_pred - y_train) / n_samples
            self.W -= learning_rate * self.dw
            # Solve W[0]*x + W[1]*y + W[-1] = 0 for y at the requested x's.
            prediction = -(x_hyperplane * self.W[0] + self.W[-1]) / self.W[1]
            predictions.append(prediction)
            # Loss of the predictions made with the weights *before* this
            # step's update (y_pred was computed above, prior to the step).
            prob = np.clip(y_pred, eps, 1 - eps)
            loss = -np.mean(
                y_train * np.log(prob) + (1 - y_train) * np.log(1 - prob)
            )
            losses.append(loss)
            if i % 10 == 0:
                print(f"At {i} epoch, loss is {loss}")
        return self.W, losses, predictions

    def prediction(self, x_test, y_test):
        """Print a score for ``x_test`` against ``y_test`` and return the probabilities.

        The score is ``1 - mean(|p - y|)`` computed on the raw probabilities
        (no thresholding), printed as a percentage.

        Args:
            x_test: design matrix of shape (n_samples, n_features).
            y_test: 0/1 labels, either (n_samples,) or (n_samples, 1).

        Returns:
            Predicted probabilities of shape (n_samples, 1).
        """
        y_test = np.asarray(y_test).reshape(-1, 1)
        y_pred = self.sigmod(np.dot(x_test, self.W))
        y_diff = y_pred - y_test
        score = 1 - np.mean(np.abs(y_diff))
        print(f"score is {score*100}%")
        return y_pred
if __name__ == "__main__":
    # Horizontal extent of the data; the decision boundary is drawn between
    # these two x-coordinates.
    xmin = np.min(X[:, 0])
    xmax = np.max(X[:, 0])
    x_hyperplane = np.array([xmin, xmax])

    # Single source of truth for the number of gradient steps.
    epochs = 1000
    l = Logstic_GD()
    Weight, losses, predictions = l.Logstic_train(
        x_hyperplane, x_train, y_train, epochs, 0.8
    )
    l.prediction(x_train, y_train)
    y_pred = l.prediction(x_test, y_test)

    # End points of the final decision boundary at xmin / xmax.
    ymin = -(xmin*Weight[0]+Weight[-1])/Weight[1]
    ymax = -(xmax*Weight[0]+Weight[-1])/Weight[1]

    plt.figure(figsize=(8, 6))

    # Left panel: the samples plus an animation of the boundary over training.
    ax = plt.subplot(1, 2, 1)
    ax.scatter(X[:, 0], X[:, 1], c=y_)
    plt.ion()
    boundary = None
    for i in range(0, len(predictions), 20):
        # Drop the previously drawn boundary through the artist itself;
        # Axes.lines is not a mutable list on current matplotlib, so
        # removing through it fails and the old lines would pile up.
        if boundary is not None:
            boundary.remove()
        (boundary,) = ax.plot(x_hyperplane, predictions[i], c='blue', linewidth=2)
        plt.pause(0.1)
    plt.ioff()

    # Right panel: loss curve over the epochs actually recorded.
    plt.subplot(1, 2, 2)
    plt.plot(np.arange(len(losses)), losses, linewidth=2)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.show()