機械学習(5)--K-meansクラスタリングアルゴリズム
K-meansアルゴリズムの概要:
1、K-meansアルゴリズムはクラスタリング中の古典アルゴリズムであり、同時に、データマイニングの古典アルゴリズムの一つでもある.
2、このアルゴリズムの主なパラメータK、すなわちいくつかのサンプルデータ数において、我々は
各サンプルがどんなクラスなのか分かりませんが、すべてのサンプルがいくつに分かれているか、あるいはいくつに分けたいかを知っています.ここでのいくつかのクラスはKです.
3、本例の基本手順
3.1前K個のサンプルを選択し、各サンプルを一種類に分け、このKサンプルの座標を中心点とする
3.2すべてのサンプルと中心点との距離を計算し、各中心点にどのようなサンプルがあるかを得る
3.3各中心点のすべてのサンプルの座標の平均値を計算し、新しい座標とする
3.4サイクル3.2ステップで、後のサイクルの各中心点に含まれるサンプルが変化しなくなるまで、サイクルを終了する
本文はmatplotlibを通じて中心点の変化と各クラスの変化を表示し、matplotlibをインストールしていない場合は、これらの文に関連する内容を遮断することができます.
プログラムの実行時にmatplotlibフォームから飛び出し、プログラムを中断し、フォームを閉じるとプログラムが実行されます.
1、K-meansアルゴリズムはクラスタリング中の古典アルゴリズムであり、同時に、データマイニングの古典アルゴリズムの一つでもある.
2、このアルゴリズムの主なパラメータK、すなわちいくつかのサンプルデータ数において、我々は
各サンプルがどんなクラスなのか分かりませんが、すべてのサンプルがいくつに分かれているか、あるいはいくつに分けたいかを知っています.ここでのいくつかのクラスはKです.
3、本例の基本手順
3.1前K個のサンプルを選択し、各サンプルを一種類に分け、このKサンプルの座標を中心点とする
3.2すべてのサンプルと中心点との距離を計算し、各中心点にどのようなサンプルがあるかを得る
3.3各中心点のすべてのサンプルの座標の平均値を計算し、新しい座標とする
3.4サイクル3.2ステップで、後のサイクルの各中心点に含まれるサンプルが変化しなくなるまで、サイクルを終了する
本文はmatplotlibを通じて中心点の変化と各クラスの変化を表示し、matplotlibをインストールしていない場合は、これらの文に関連する内容を遮断することができます.
プログラムの実行時にmatplotlibフォームから飛び出し、プログラムを中断し、フォームを閉じるとプログラムが実行されます.
# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
# , matplotlib
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
# , ,
data='''
1,1
2,1
4,5
6,6
5,4
3,3
2,2
3,2
5,6
1,3
3,1
6,5
''';
#
#[0,0] , 4 ,[x,y,0,0]
# 1 ,
# 2 , , 1,
data=[x.split(',')+[0,0] for x in data.split('
')]
data=list(filter(lambda x: len(x)==4,data))
data=np.array(data).astype(np.float)
#print(data)
# , K ,
k=2 # matplotlib , ,
centroids=data.copy()[:k,:-2]#-2
def draw(centroids,data,title):
plt.axis([round((np.min(data,axis=0)-1)[0])
,round((np.max(data,axis=0)+1)[0])
,round((np.min(data,axis=0)-1)[1])
,round((np.max(data,axis=0)+1)[1])]) # X,Y
plt.title(title)
for index,center in enumerate(centroids):
colorStr='rgby'[index:index+1] # matplotlib , ,
centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))
if len(centerData)>0 :plt.scatter(centerData[:,0],centerData[:,1],c=colorStr)
plt.scatter(center[0],center[1],c=colorStr,marker='x')
plt.show()
runtimes=0
changePointLength=-1
while changePointLength!=0:
runtimes+=1
draw(centroids,data,' %d '%runtimes + (' , ' if runtimes==1 else ''))
#3.2 ,
for dataItem in data:
distances=np.sqrt(((centroids-dataItem[:-2])**2).sum(axis=1))#
minDisType=np.argmin(distances)+1 #
if dataItem[-2]==minDisType:
dataItem[-1]=0 #
else :
dataItem[-1]=1 #
dataItem[-2]=minDisType
print(data)
#3.3 ,
for index,center in enumerate(centroids):
centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))[:,:-2] #
centerData=centerData.mean(axis=0)
center[0]=centerData[0]
center[1]=centerData[1]
#print(data)
changePointLength=len(list(filter(lambda x:x[-1]==1 ,data))) # , ,