KNNアルゴリズム

61756 ワード

KNNとは
KNN(K-nearest neighbor)、すなわちK近隣アルゴリズムである.その動作原理は非常に簡単です. 
例えば、映画に登場する戦闘シーンとキスシーンの数によって、恋愛映画とアクション映画に分けられるトレーニングセットのデータがあります.
映画のタイトル
ファイト?シーン
キスシーン
映画の種類
California Man
3
104
ラブ映画
He's Not Really into Dudes
2
100
ラブ映画
Beautiful Woman
1
81
ラブ映画
Kevin Longblade
101
10
アクションフィルム
Robo Slayer 3000
99
5
アクションフィルム
Amped II
98
2
アクションフィルム
?
18
90
不明
では、次のような未知の映画をどのように分類しますか? 
映画のタイトル
ファイト?シーン
キスシーン
映画の種類
?
18
90
不明
KNNは、新しいデータの各特徴をサンプルセットの各サンプルと比較する、サンプルとの距離を算出する方法を採用する.そしてk個の距離が最も近いサンプルを探し出す.これらのサンプルが属する分類が最も多い分類は、そのデータの分類である. 
例えば、映画「?」上記サンプルセットの各サンプルと計算する、距離を得る.Kが3を取ると「?3つの恋愛映画との距離が一番近いので、「?ラブ映画でもあります. 
アルゴリズムを以下に示す. 
アルゴリズムの詳細
PythonでKNNアルゴリズムを実現し、上述の分類問題を解決した. 
まず、アルゴリズムで使用されるデータセットを生成します.
 
     
  1. def createDataSet():
  2. '''
  3. '''
  4. group = np.array([[3, 104], [2, 100], [1, 81], [101, 10],[99, 5], [98, 2]])
  5. labels = ['A', 'A', 'A', 'D','D', 'D']
  6. return group, labels

然后写出 KNN 分类算法:
 
     
  1. def classify0(inX, dataSet, labels, k):
  2. '''
  3. KNN
  4. :param inX:
  5. :param dataSet:
  6. :param labels:
  7. :param k:
  8. :return:
  9. '''
  10. dataSetSize = dataSet.shape[0]
  11. diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
  12. sqDiffMat = diffMat ** 2
  13. sqDistances = sqDiffMat.sum(axis=1)
  14. distances = sqDistances ** 0.5
  15. sortedDistIndicies = distances.argsort()
  16. classCount = {}
  17. for i in range(k):
  18. voteIlabel = labels[sortedDistIndicies[i]]
  19. classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
  20. sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
  21. print(sortedClassCount)
  22. return sortedClassCount[0][0]

KNN 算法在使用的时候,一般需要对特征数据进行归一化处理:
 
     
  1. def autoNorm(dataSet):
  2. '''
  3. :param dataSet:
  4. :return:
  5. '''
  6. minVals = dataSet.min(0)
  7. maxVals = dataSet.max(0)
  8. ranges = maxVals - minVals
  9. normDataSet = np.zeros(np.shape(dataSet))
  10. m = dataSet.shape[0]
  11. normDataSet = dataSet - np.tile(minVals, (m, 1))
  12. normDataSet = normDataSet/np.tile(ranges, (m,1))
  13. return normDataSet, ranges, minVals


我们也可以使用 matplotlib 画出二维空间的散点图
 
     
  1. def plot(dataSet):
  2. fig = plt.figure()
  3. ax = fig.add_subplot(111)
  4. ax.scatter(dataSet[:, 0], dataSet[:, 1])
  5. plt.show()


整体代码如下:
 
     
  1. #!/usr/bin/python
  2. # -*- coding: UTF-8 -*-
  3. import numpy as np
  4. import operator
  5. import matplotlib
  6. import matplotlib.pyplot as plt
  7. def createDataSet():
  8. '''
  9. '''
  10. group = np.array([[3, 104], [2, 100], [1, 81], [101, 10],[99, 5], [98, 2]])
  11. labels = ['A', 'A', 'A', 'D','D', 'D']
  12. return group, labels
  13. def classify0(inX, dataSet, labels, k):
  14. '''
  15. KNN
  16. :param inX:
  17. :param dataSet:
  18. :param labels:
  19. :param k:
  20. :return:
  21. '''
  22. dataSetSize = dataSet.shape[0]
  23. diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
  24. sqDiffMat = diffMat ** 2
  25. sqDistances = sqDiffMat.sum(axis=1)
  26. distances = sqDistances ** 0.5
  27. sortedDistIndicies = distances.argsort()
  28. classCount = {}
  29. for i in range(k):
  30. voteIlabel = labels[sortedDistIndicies[i]]
  31. classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
  32. sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
  33. print(sortedClassCount)
  34. return sortedClassCount[0][0]
  35. def autoNorm(dataSet):
  36. '''
  37. :param dataSet:
  38. :return:
  39. '''
  40. minVals = dataSet.min(0)
  41. maxVals = dataSet.max(0)
  42. ranges = maxVals - minVals
  43. normDataSet = np.zeros(np.shape(dataSet))
  44. m = dataSet.shape[0]
  45. normDataSet = dataSet - np.tile(minVals, (m, 1))
  46. normDataSet = normDataSet/np.tile(ranges, (m,1))
  47. return normDataSet, ranges, minVals
  48. def plot(dataSet):
  49. fig = plt.figure()
  50. ax = fig.add_subplot(111)
  51. ax.scatter(dataSet[:, 0], dataSet[:, 1])
  52. plt.show()
  53. if __name__ == '__main__':
  54. dataSet, labels = createDataSet()
  55. dataSet, _, _ = autoNorm(dataSet)
  56. print(dataSet)
  57. plot(dataSet)
  58. print(classify0([18, 90], dataSet, labels,4))

KNN

KNN . KNN . ,KNN , .