Logistic回帰多分類のアヤメ
作者:金良([email protected])csdnブログ:http://blog.csdn.net/u012176591
1.複数の論理回帰モデルの原理
論理回帰モデルは二分類モデルであり,二分類問題に用いられる.これをマルチカテゴリ用に複数の論理回帰モデル(multi-nominal logistic regression model)に拡張することができる.
カテゴリYの値セットが{1,2,⋯,K}であると仮定すると、複数の論理回帰モデルは
P(y=k|x)=exp(wk⋅x)1+∑K−1k=1exp(wk⋅x)k=1,2,⋯,K−1P(y=K|x)=11+∑K−1k=1exp(wk⋅x)
尤度関数は
∏i=1N∏k=1KP(yi=k|xi)yki
内
P(yi=k|xi)はモデル入力サンプル
xiの場合はカテゴリとして扱います
kの確率は、
ykiは関数を示す役割を果たし、
kイコールサンプル
xiのラベルカテゴリは1であり、残りはいずれも0である.
尤度関数に対して対数をとり,次に負をとり,L(w 1,w 2,⋯,wK−1)(L(w)と略記する)を得,最終的に訓練するモデルパラメータw 1,w 2,⋯,wK−1はL(w)の値を最小にする.
L(w)=−1N∑i=1N∑k=1KykilogP(yi=k|xi)=−1N∑i=1N(∑k=1K−1ykilogexp(wk⋅xi)1+∑K−1j=1exp(wj⋅xi)+yKilog11+∑K−1j=1exp(wj⋅xi))=−1N∑i=1N(∑k=1K−1yki(wk⋅xi−log(1+∑j=1K−1exp(wj⋅xi)))−yKilog(1+∑j=1K−1exp(wj⋅xi)))=1N∑i=1N(∑k=1Kykilog(1+∑j=1K−1exp(wj⋅xi))−∑k=1K−1ykiwk⋅xi)
オーバーフィットの発生を考慮してL(w)に正規化項を加える
∑K−1k=1|wk|22σ2
L(w)を書く
L(w)=1N∑i=1N(∑k=1Kykilog(1+∑j=1K−1exp(wj⋅xi))−∑k=1K−1ykiwk⋅xi)+∑k=1K−1|wk|22σ2
L(w)に対してwkについて勾配を求めて、得る
∂L(w)wk=1N∑i=1Nexp(wk⋅xi)1+∑K−1k=1exp(wk⋅xi)⋅xi−1N∑i=1Nykixi+wkσ2=1N∑i=1NP(yi=k|xi)⋅xi−1N∑i=1Nykixi+wkσ2
第1項1 NΣNi=1 P(yi=k|xi)⋅xiはカテゴリkの事後期待値と見なすことができ、第2項1 NΣNi=1 ykixiはカテゴリkの先行期待値と見なすことができ、第3項は正規化項であり、過フィッティングを緩和する.
次に勾配降下法でパラメータw 1,w 2,⋯,wK−1を補正すればよい.
2.アヤメ花データの可視化
作図コードは以下の通りです.
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import itertools as it
import matplotlib as mpl
#mpl.rcParams['xtick.labelsize'] = 6
#mpl.rcParams['ytick.labelsize'] = 6
mpl.rcParams['axes.labelsize'] = 6 #
attributes = ['SepalLength','SepalWidth','PetalLength','PetalWidth'] #
comb = list(it.combinations([0,1,2,3],3)) #
datas = []
labels = []
with open('Iris.txt','r') as f:
for line in f:
linedata = line.split(' ')
datas.append(linedata[:-1]) # 4 4
labels.append(linedata[-1].replace('
','')) #
datas = np.array(datas)
datas = datas.astype(float) #
labels = np.array(labels)
kinds = list(set(labels)) #3
fig = plt.figure()
plt.title(u' ',{'fontname':'STFangsong','fontsize':10})
for k in range(4):
ax = fig.add_subplot(int('22'+str(k+1)), projection='3d')
for label,color,marker in zip(kinds,['r','b','k','y','c','g'],['s','*','o','^']):
data = datas[labels == label ]
for i in range(data.shape[0]):
if i == 0: # label , , label ,
ax.plot([data[i,comb[k][0]]], [data[i,comb[k][1]]], [data[i,comb[k][2]]], color=color,marker = marker,markersize=3,label = label,linestyle = 'None')
else:
ax.plot([data[i,comb[k][0]]], [data[i,comb[k][1]]], [data[i,comb[k][2]]], color=color,marker = marker,markersize=3,linestyle = 'None')
ax.legend(loc="upper left",prop={'size':8},numpoints=1) #
ax.set_xlabel(attributes[comb[k][0]],fontsize=8) #
ax.set_ylabel(attributes[comb[k][1]],fontsize=8)
ax.set_zlabel(attributes[comb[k][2]],fontsize=8)
ax.set_xticks(np.linspace(np.min(datas[:,comb[k][0]]),np.max(datas[:,comb[k][0]]),3)) #
ax.set_yticks(np.linspace(np.min(datas[:,comb[k][1]]),np.max(datas[:,comb[k][1]]),3))
ax.set_zticks(np.linspace(np.min(datas[:,comb[k][2]]),np.max(datas[:,comb[k][2]]),3))
plt.subplots_adjust(wspace = 0.1,hspace=0.1) #
savefig('Iris.png',dpi=700,bbox_inches='tight')
plt.show()
3.アルゴリズム実装コード
attributes = ['SepalLength','SepalWidth','PetalLength','PetalWidth'] #
datas = []
labels = []
with open('Iris.txt','r') as f:
for line in f:
linedata = line.split(' ')
datas.append(linedata[:-1]) # 4 4
labels.append(linedata[-1].replace('
','')) #
datas = np.array(datas)
datas = datas.astype(float) #
labels = np.array(labels)
kinds = list(set(labels)) #3
means = datas.mean(axis=0) #
stds = datas.std(axis=0) #
N,M = datas.shape[0],datas.shape[1]+1 #N ,M
K = 3 #K
data = np.ones((N,M))
data[:,1:] = (datas - means)/stds #
W = np.zeros((K-1,M)) #
priorEs = np.array([1.0/N*np.sum(data[labels == kinds[i]],axis=0) for i in range(K-1)]) #
liklist = []
for it in range(1000):
lik = 0 #
for k in range(K-1):#
lik -= np.sum(np.dot(W[k],data[labels == kinds[k]].transpose()))
lik += 1.0/N*np.sum(np.log(np.sum(np.exp(np.dot(W,data.transpose())),axis = 0) +1)) #
liklist.append(lik) #
wx = np.exp(np.dot(W,data.transpose()))
probs = np.divide(wx,1+np.sum(wx,axis=0).transpose()) #K-1*N
posteriorEs = 1.0/N*np.dot(probs,data) #
gradients = posteriorEs - priorEs + 1.0/100 *W # , ,
W -= gradients #
#probM , data
probM = np.ones((N,K))
probM[:,:-1] = np.exp(np.dot(data,W.transpose()))
probM /= np.array([np.sum(probM,axis = 1)]).transpose() #
predict = np.argmax(probM,axis = 1).astype(int) #
#rights , labels
rights = np.zeros(N)
rights[labels == kinds[1]] =1
rights[labels == kinds[2]] =2
rights = rights.astype(int)
print ' :%d
'%np.sum(predict != rights) #
このアルゴリズムの収束性は、次の図のようになります.
4.混同行列
#
K = np.shape(list(set(rights).union(set(predict))))[0]
matrix = np.zeros((K,K))
for righ,predic in zip(rights,predict):
matrix[righ,predic] += 1
print "%-12s" % (" "),
for i in range(K):
print "%-12s" % (kinds[i]),
print '
',
for i in range(K):
print "%-12s" % (kinds[i]),
for j in range(K):
print "%-12d" % (matrix[i][j]),
print '
',
上記混同マトリクス(confusion matrix)の行は実際のカテゴリを表し、列はモデル出力カテゴリを表す.クラス
setosa
のサンプルは他のカテゴリと誤認されず、他のカテゴリのサンプルはクラスsetosa
と誤認されていないことが分かった.すなわち、私たちが訓練したモデルはクラスsetosa
のサンプルを他のクラスのサンプルと混同していない.4つのversicolor
サンプルが「virginica」と誤審され、4つの「virginica」サンプルがversicolor
と誤審された.全体的に見ると、分類効果は悪くない.5.更なるパッケージング
def LogisticRegression(datas,labels):
kinds = list(set(labels)) #3
means = datas.mean(axis=0) #
stds = datas.std(axis=0) #
N,M = datas.shape[0],datas.shape[1]+1 #N ,M
K = np.shape(list(set(labels)))[0] #K
data = np.ones((N,M))
data[:,1:] = (datas - means)/stds #
W = np.zeros((K-1,M)) #
priorEs = np.array([1.0/N*np.sum(data[labels == kinds[i]],axis=0) for i in range(K-1)]) #
for it in range(1000):
wx = np.exp(np.dot(W,data.transpose()))
probs = np.divide(wx,1+np.sum(wx,axis=0).transpose()) #K-1*N
posteriorEs = 1.0/N*np.dot(probs,data) #
gradients = posteriorEs - priorEs + 1.0/100 *W # , ,
W -= gradients #
#probM , data
probM = np.ones((N,K))
probM[:,:-1] = np.exp(np.dot(data,W.transpose()))
probM /= np.array([np.sum(probM,axis = 1)]).transpose() #
return probM
if __name__ == "__main__":
datas = []
labels = []
with open('Iris.txt','r') as f:
for line in f:
linedata = line.split(' ')
datas.append(linedata[:-1]) # 4 4
labels.append(linedata[-1].replace('
','')) #
datas = np.array(datas)
datas = datas.astype(float) #
labels = np.array(labels)
kinds = list(set(labels)) #3
#LogisticRegression , ( )
m,n = 1,2
sel = np.array(list(where(labels == kinds[m])[0])+(list(where(labels == kinds[n])[0])))
slabels = labels[sel]
sdatas = datas[sel]
probM = LogisticRegression(sdatas,slabels)
print probM