Pythonによるロジスティック回帰分類器（LR）の実装

2253 ワード

# --*-- coding:utf-8 --*--
import numpy as np

class Logistic:
    """Binary logistic-regression (LR) classifier trained by batch gradient descent.

    The model is ``P(y=1 | x) = sigmoid(w . x)`` where every sample is stored
    as ``x = (1, x1, ..., xn)``; the leading constant 1.0 makes ``w[0]`` the
    bias (intercept) term.
    """

    def loadDataSet(self, fileName='testSet.txt'):
        """Read a whitespace-separated data file.

        Each non-blank line holds the feature values followed by an integer
        class label in the last column.

        Args:
            fileName: path of the data file (default ``'testSet.txt'``).

        Returns:
            ``(dataMat, labelMat)``: ``dataMat`` is a list of rows
            ``[1.0, x1, ..., xn]`` (the leading 1.0 is the bias input) and
            ``labelMat`` is the list of integer labels.
        """
        dataMat = []
        labelMat = []
        # ``with`` guarantees the file handle is closed, even on parse errors.
        with open(fileName) as fr:
            for line in fr:
                lineArr = line.split()
                if not lineArr:  # tolerate blank lines (e.g. a trailing newline)
                    continue
                # All columns but the last are features; prepend the constant
                # 1.0 so that weights[0] acts as the intercept.
                dataMat.append([1.0] + [float(v) for v in lineArr[:-1]])
                labelMat.append(int(lineArr[-1]))
        return dataMat, labelMat

    def sigmoid(self, inX):
        """Element-wise logistic function ``1 / (1 + exp(-inX))``."""
        return 1.0 / (1 + np.exp(-inX))

    def train(self, dataSet, labels, alpha=0.01, maxIter=500):
        """Fit the weights with batch gradient descent using ``numpy.matrix``.

        Args:
            dataSet: m samples, each ``[1.0, x1, ..., xn]`` (shape (m, n)).
            labels: m class labels (0 or 1).
            alpha: learning rate (step size).
            maxIter: number of full-batch gradient steps.

        Returns:
            ``numpy.matrix`` of shape (n, 1) holding ``(b, w1, ..., wn)``.
        """
        # np.mat was removed in NumPy 2.0; np.asmatrix is the portable spelling.
        dataMat = np.asmatrix(dataSet, dtype=float)       # (m, n)
        labelMat = np.asmatrix(labels, dtype=float).T     # (m, 1)
        n = dataMat.shape[1]
        weights = np.asmatrix(np.ones((n, 1)))            # (n, 1)
        for _ in range(maxIter):
            h = self.sigmoid(dataMat * weights)  # predicted P(y=1), (m, 1)
            error = h - labelMat                 # (m, 1)
            # X^T (h - y) is the gradient of the negative log-likelihood.
            weights = weights - alpha * dataMat.T * error
        return weights

    def nparraytrain(self, dataSet, labels, alpha=0.01, maxIter=500):
        """Same algorithm as ``train`` but with plain ``ndarray`` arithmetic.

        Args:
            dataSet: m samples, each ``[1.0, x1, ..., xn]`` (shape (m, n)).
            labels: m class labels (0 or 1).
            alpha: learning rate (step size).
            maxIter: number of full-batch gradient steps.

        Returns:
            ``ndarray`` of shape (n, 1) holding ``(b, w1, ..., wn)``.
        """
        dataArr = np.asarray(dataSet, dtype=float)                 # (m, n)
        labelArr = np.asarray(labels, dtype=float).reshape(-1, 1)  # (m, 1)
        weights = np.ones((dataArr.shape[1], 1))                   # (n, 1)
        for _ in range(maxIter):
            h = self.sigmoid(np.dot(dataArr, weights))  # (m, 1)
            error = h - labelArr                        # (m, 1)
            weights = weights - alpha * np.dot(dataArr.T, error)
        return weights

    def classify(self, X, weights):
        """Predict the class of one sample.

        Args:
            X: one sample ``[1.0, x1, ..., xn]`` (list, 1-D array or row).
            weights: n weights in any layout produced by ``train`` /
                ``nparraytrain`` -- (n,), (n, 1) array, or (n, 1) matrix.

        Returns:
            1.0 if ``sigmoid(w . x) > 0.5`` else 0.0.
        """
        # Flatten both operands before the dot product: multiplying a (n,)
        # sample by an (n, 1) weight column would broadcast to (n, n).
        z = np.dot(np.asarray(X, dtype=float).ravel(),
                   np.asarray(weights, dtype=float).ravel())
        prob = self.sigmoid(z)
        return 1.0 if prob > 0.5 else 0.0


if __name__ == '__main__':
    # Script entry point: load the default data file, fit the ndarray-based
    # model, and print the learned weight column.
    model = Logistic()
    samples, targets = model.loadDataSet()
    learned = model.nparraytrain(samples, targets)
    print(learned)