機械学習−感知機モデル(pocketアルゴリズム)実装


前回考慮したセンサモデルには,データが線形に分割可能であるという仮定がある.実際,騒音や他の要因の存在により,いずれも線形に分けることはできない.したがって,非線形分割が考えられる場合の実現方法が必要である.
ここではPocketアルゴリズムを使用します.
Pocketアルゴリズムの考え方は非常に簡単で、Wを検索する時、絶えず最高の精度とWを記録します.これにより、データが線形に分割されなくても、反復の回数を絶えず増加させる限り、比較的良いテスト結果が得られる.
データ:
https://www.csie.ntu.edu.tw/~htlin/course/ml15fall/hw1/hw1_18_train.dat
https://www.csie.ntu.edu.tw/~htlin/course/ml15fall/hw1/hw1_18_test.dat
計算精度関数:
#      
def checkErrorRate(testMatData, testLabelData, W):
    accuracyCount = 0
    for i in range(len(testMatData)):
        vect = testMatData[i, :]
        extraBiasVect = append(1, vect)
        resultY = vdot(W, extraBiasVect)
        if (resultY <= 0):
            labelY = -1
        else:
            labelY = 1
        if (labelY == testLabelData[i]):
            accuracyCount += 1
    return accuracyCount / len(testLabelData)

Pocketアルゴリズム
#            ,pocketPerceptron  
def pocketPerceptronLearn(trainMatData, trainLabelData, testMatData, testLabelData):
    #         
    maxIteration = 100000
    #     
    W = [0, 0, 0, 0, 0]
    # labely
    iterationFinish = False
    #       
    times = 0
    bestW = W
    bestAccuracyRate = 0
    for interationCount in range(maxIteration):
        for dataIndex in range(len(trainMatData)):
            #       
            vect = trainMatData[dataIndex, :]
            extraBiasVect = append(1, vect)
            resultY = vdot(W, extraBiasVect)
            if (resultY <= 0):
                labelY = -1
            else:
                labelY = 1

            if (labelY != trainLabelData[dataIndex]):
                W = W + trainLabelData[dataIndex] * extraBiasVect
                times += 1
                rate = checkErrorRate(testMatData, testLabelData, W)
                if (rate > bestAccuracyRate):
                    bestAccuracyRate = rate
                    bestW = W

            else:
                if (dataIndex == (len(trainMatData) - 1)):
                    iterationFinish = True
        if (iterationFinish == True):
            break
            #     
        if (times >= 50):
            print(bestW)
            print(bestAccuracyRate)
    return bestW, bestAccuracyRate

Pythonを使い始めたばかりで、多くのマトリクス/配列などの数学の操作がくどくて、効率もよくありません.使用中にどんどん悪補をしていきましょう.