機械学習実戦のk-近隣アルゴリズム(KNN)手書き数字認識

12528 ワード

機械学習実戦のk-近隣アルゴリズム(KNN)手書き数字認識
『機械学習実戦』第2章k-近隣アルゴリズムの識別手書き数字
k-近隣アルゴリズムの原理
  • ターゲット要素と既知のカテゴリ要素との距離を計算することにより、距離が小さいほど、要素間が類似することを示す
  • .
  • 最も近いk個の要素
  • をとる
  • というk個の要素のうち、最も出現回数が多いカテゴリは、ターゲット要素のカテゴリ
  • である.
    コードと解釈
    import numpy as np
    from os import listdir		#            
    
    #         
    def classify0(X, dataset, labels, k):
        size = dataset.shape[0]
        dis = np.sqrt(np.sum(np.square(dataset-X), axis=1))		#     、  、  
        sort_index = np.argsort(dis)		#      ,          
        label_count = {}
        for i in range(k):
            label = labels[sort_index[i]]
            label_count[label] = label_count.get(label, 0)+1
        sort_label = sorted(label_count.items(), key=lambda d: d[1], reverse=True)		#            
        return sort_label[0][0]
    
    # 32*32           1*1024   
    def digVector(filename):
        vect = np.zeros((1, 1024))
        f = open(filename)
        for i in range(32):
            lines = f.readline()
            for j in range(32):
                vect[0, 32*i+j] = int(lines[j])
        return vect
    
    def classifyDigits():
        trainList = listdir('trainingDigits')		#           
        nTrain = len(trainList)
        train_X = np.zeros((nTrain, 1024))
        train_y = []
        for i in range(nTrain):
            filename = trainList[i]
            train_X[i, :] = digVector('trainingDigits/%s'%filename)
            train_y.append(int(filename.split('.')[0].split('_')[0]))		#     ,       , 0_24.txt,     0
    
        testList = listdir('testDigits')
        nTest = len(testList)
        error_count = 0
        for i in range(nTest):
            filename = testList[i]
            label = int(filename.split('.')[0].split('_')[0])
            test = digVector('testDigits\%s'%filename)
            result = classify0(test, train_X, train_y, 3)
            print('predict: %d, real: %d'%(result, label))
            if result != label: error_count += 1
        print('the error rate: ', error_count/float(nTest))
    

    欠点
  • 時間非効率
  • 各ターゲットベクトルテストnTrain次距離計算
  • を行う
  • 各距離計算はm次元の浮動小数点演算を含む、例えば手書き数字が1024次元の
  • であることを識別する.
  • ストレージスペースのオーバーヘッドが大きい
  • 訓練データに空間を用意する必要があり、訓練データが大きい場合、大量の記憶空間
  • を使用しなければならない.
  • は、データのインフラストラクチャ情報を与えることができず、平均インスタンスサンプルおよび典型的なインスタンスサンプルの特徴を知ることができない
  • .