MNISTデータセットをKNNを使って実験します.
3160 ワード
SVMを使ってMNISTデータセットを分類実験したところ、98.46%の分類正確率が得られました.
今日はpythonで小さなプログラムを書いて、KNNの分類効果をテストしました.
KNNの計算量が大きいので、KD-treeを使って最適化していません.600訓練セットに対して、10000テストセットのデータ計算は遅いです.ここではKNNの効果をテストしてみたいだけです.調整しません.
Kは以前に見たことがあるので、20を超えないほうがいいと思います.ここでK=10を選びました.距離はヨーロッパ式です.改善が必要なら、Kを再調整して最高の成績を選ぶことができます.
まず、scaleを通さないで、つまり画素階調値を使ってヨーロッパの距離を計算して比較します.最初は基本的に安定していましたが、正しい確率の95%にびっくりしました.もともとKNNはあまり「勉強」のない機械学習アルゴリズムだと思っていたので、その特徴はどのような状況でも使えると推測していましたが、最高ではありませんでした.だから60%~80%は受け入れられると思います.基本的には95%で安定できるとは思いませんでしたが、アルゴリズムとコードを確認しても大丈夫です.このデータセットのほうが挑戦的ではないかと急に思いました.
MNIST公式サイトに行きますhttp://yann.lecun.com/exdb/mnist/)このデータセットをデータとするアルゴリズムの結果比較が上に掛けられている.KNNを調べてみると、多くの発見があり、誤り率は大体5%以内で、1%以内までできます.うん
走った結果、正しい率は96.688%でした.つまり、エラー率error rateは3.31%程度です.
もう一度走るとscaleを通過するデータ、つまり階調データを[0,1]の範囲に正規化します.効果が上がりましたか?
scaleを経て、最終的に走った結果、正確率はなんと96.688%です.つまり、このデータセットの下で、KNNのデータを正規化するかどうかは効果がないということです.
scaleを走る前に、個人的な推測:データを処理する前に正規化して、高次元ののろいを防ぐ(784次元の空間で高次元ののろいを受けやすい).したがって、scaleは前者より良いと予想されます.しかし、今は同じ結果になりそうです.つまり、K=10のKNNアルゴリズムではMNISTに対する予測は同じです.
scaleの前後の正確度は同じです.訓練セットには600個のデータ点がありますから、0-9の各分類には平均6000個のデータ点があります.このような場合、テストデータセットのデータ点に対して、近い10点のほとんどが他のクラスです.この時,KNNはより良い分類効果を得るだけでなく,scaleかどうかには敏感ではなく,効果は同じである.
コードは以下の通りです
今日はpythonで小さなプログラムを書いて、KNNの分類効果をテストしました.
KNNの計算量が大きいので、KD-treeを使って最適化していません.600訓練セットに対して、10000テストセットのデータ計算は遅いです.ここではKNNの効果をテストしてみたいだけです.調整しません.
Kは以前に見たことがあるので、20を超えないほうがいいと思います.ここでK=10を選びました.距離はヨーロッパ式です.改善が必要なら、Kを再調整して最高の成績を選ぶことができます.
まず、scaleを通さないで、つまり画素階調値を使ってヨーロッパの距離を計算して比較します.最初は基本的に安定していましたが、正しい確率の95%にびっくりしました.もともとKNNはあまり「勉強」のない機械学習アルゴリズムだと思っていたので、その特徴はどのような状況でも使えると推測していましたが、最高ではありませんでした.だから60%~80%は受け入れられると思います.基本的には95%で安定できるとは思いませんでしたが、アルゴリズムとコードを確認しても大丈夫です.このデータセットのほうが挑戦的ではないかと急に思いました.
MNIST公式サイトに行きますhttp://yann.lecun.com/exdb/mnist/)このデータセットをデータとするアルゴリズムの結果比較が上に掛けられている.KNNを調べてみると、多くの発見があり、誤り率は大体5%以内で、1%以内までできます.うん
走った結果、正しい率は96.688%でした.つまり、エラー率error rateは3.31%程度です.
もう一度走るとscaleを通過するデータ、つまり階調データを[0,1]の範囲に正規化します.効果が上がりましたか?
scaleを経て、最終的に走った結果、正確率はなんと96.688%です.つまり、このデータセットの下で、KNNのデータを正規化するかどうかは効果がないということです.
scaleを走る前に、個人的な推測:データを処理する前に正規化して、高次元ののろいを防ぐ(784次元の空間で高次元ののろいを受けやすい).したがって、scaleは前者より良いと予想されます.しかし、今は同じ結果になりそうです.つまり、K=10のKNNアルゴリズムではMNISTに対する予測は同じです.
scaleの前後の正確度は同じです.訓練セットには600個のデータ点がありますから、0-9の各分類には平均6000個のデータ点があります.このような場合、テストデータセットのデータ点に対して、近い10点のほとんどが他のクラスです.この時,KNNはより良い分類効果を得るだけでなく,scaleかどうかには敏感ではなく,効果は同じである.
コードは以下の通りです
#KNN for MNIST
from numpy import *
import operator
def line2Mat(line):
line = line.strip().split(' ')
label = line[0]
mat = []
for pixel in line[1:]:
pixel = pixel.split(':')[1]
mat.append(float(pixel))
return mat, label
#matrix should be type: array. Or classify() will get error.
def file2Mat(fileName):
f = open(fileName)
lines = f.readlines()
matrix = []
labels = []
for line in lines:
mat, label = line2Mat(line)
matrix.append(mat)
labels.append(label)
print 'Read file '+str(fileName) + ' to matrix done!'
return array(matrix), labels
#classify mat with trained data: matrix and labels. With KNN's K set.
def classify(mat, matrix, labels, k):
diffMat = tile(mat, (shape(matrix)[0], 1)) - matrix
#diffMat = array(diffMat)
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistanceIndex = distances.argsort()
classCount = {}
for i in range(k):
voteLabel = labels[sortedDistanceIndex[i]]
classCount[voteLabel] = classCount.get(voteLabel,0) + 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
def classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K):
rightCnt = 0
for i in range(len(testMatrix)):
if i % 100 == 0:
print 'num '+str(i)+'. ratio: '+ str(float(rightCnt)/(i+1))
label = testLabels[i]
predictLabel = classify(testMatrix[i], trainMatrix, trainLabels, K)
if label == predictLabel:
rightCnt += 1
return float(rightCnt)/len(testMatrix)
trainFile = 'train_60k.txt'
testFile = 'test_10k.txt'
trainMatrix, trainLabels = file2Mat(trainFile)
testMatrix, testLabels = file2Mat(testFile)
K = 10
rightRatio = classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K)
print 'classify right ratio:' +str(right)