Irisデータを使ってAutoencoderを試してみた


目的

前回Autoencoderを使って直線回帰がうまくできたので、今回は”みんな大好き”irisデータを使って試してみたいと思います

内容

IrisデータのTarget毎に別々の特徴量データフレームを作り、それぞれを使ってAutoencoderを作り、入力と出力の差分からTarget(つまり同じ種類のIris)の違いを識別できるかの確認
イメージとしてはAutoencoderによる故障検知をIrisを使ってやるということです

コード

初期設定

iris_autoencoder.ipynb
import seaborn as sns

from sklearn.datasets import load_iris
from keras.models import Sequential
from keras.layers import Dense
import pandas as pd
import numpy as np

iris=load_iris()

X=pd.DataFrame(iris["data"])
y=iris["target"]
Xa={}
Xa[0]=X.loc[y==0]
Xa[1]=X.loc[y==1]
Xa[2]=X.loc[y==2]

モデル定義と可視化関数

iris_autoencoder.ipynb
def engine(j):
    encode_dim=2
    hidden_dim=50
    model=Sequential()
    model.add(Dense(hidden_dim,input_dim=X.shape[1]))
    model.add(Dense(encode_dim))
    model.add(Dense(hidden_dim))
    model.add(Dense(X.shape[1]))
    model.summary()

    model.compile(optimizer="adam",loss="mean_squared_error")

    model.fit(Xa[j],Xa[j],epochs=500,
              batch_size=150,verbose=0)

    import matplotlib.pyplot as plt
    for label in range(3):
        plt.scatter(Xa[label].iloc[:,0],Xa[label].iloc[:,1],label=label)
    pred=pd.DataFrame(model.predict(X))
    plt.scatter(pred.iloc[:,0],pred.iloc[:,1])
    plt.legend()
    plt.show()
    delta=pred-X
    for label in range(3):
        delta1=delta.loc[y==label]
        plt.scatter(delta1.iloc[:,0],delta1.iloc[:,1],label=label)
    plt.legend()
    plt.vlines([0],ymax=2,ymin=-2)
    plt.hlines([0],xmax=2,xmin=-2)
    plt.savefig("fig_1_"+str(j)+".png")
    plt.show()

    norm=[]
    for i in range(len(X)):
        norm.append(np.linalg.norm(delta.iloc[i,:]))
    norm=pd.DataFrame(norm)

    for i in range(3):
        normd=norm.loc[y==i]
        sns.distplot(normd)
    plt.savefig("fig_2_"+str(j)+".png")
    plt.show()

実行部分

iris_autoencoder.ipynb
for i in range(len(Xa)):
    engine(i)

実行結果と考察

下のヒストグラムは、入力ベクトルから出力ベクトルを引き算してノルムから算出したものです(要するに誤差の距離のヒストグラム)。みてわかる様にターゲットのみを学習しているので、ターゲットは比較的に0に近い値となり、それ以外のデータは距離が離れているのがわかります
なので、この程度はっきりした特徴量の差があればAutoencoderを使って故障検知ができると言えます。
この場合で言うと異なる種類の花の特徴量が混じっていればそれを指摘できると言うことになります


ターゲットは青です

ターゲットはオレンジです

ターゲットは緑です