Light GBMの多クラス分類/二値分類を可視化するConfusion Matrix Plot


Introduction

Summary

・scikit-learn 0.22のアップデートであるconfusion matrixのplotを試してみた
・従来のscikit-learnのconfusion matrixはarrayで出力していたのでグラフ映えしなかった
・LightGBMを用いる場合はファンクションの引数の分類モデルを
 scikit-learnインターフェースで学習させないとできなかった
 scikit-learnの呪い?
・plotするメリットとしては、ファンクション内でy_predを勝手に計算してくれるので
 ちょっと楽なのと見やすくて映える

Dataset

みんな大好きiris。目的変数はもちろんspecies。

Body

準備

まずはアップデートを忘れずに。

ライブラリのアップデート
!pip install --upgrade scikit-learn

下記の_plot_confusion_matrixが今回のテーマのメイン。

今回必要なライブラリのimport
#default
import numpy as np
import pandas as pd
import seaborn as sns

#for_modeling
import lightgbm as lgb

#for_plot
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, plot_confusion_matrix
from sklearn.model_selection import train_test_split

irisデータセットはseabornライブラリからimportします。

データセットの整理
df = sns.load_dataset('iris')

X=df.drop(["species"],axis=1)
y=df["species"]

SEED=5
X_train_tmp, X_test,y_train_tmp,y_test=train_test_split(X, y, test_size=0.2, random_state=SEED)
X_train, X_val, y_train, y_val=train_test_split(X_train_tmp, y_train_tmp, test_size=0.2, random_state=SEED)

オリジナルインターフェースで学習させたパターン

Summaryで書いた通り失敗するのだが、
まずはLightGBMモデルをオリジナルのインターフェースで記述したケースで実行してみる。

オリジナルインターフェースでの学習
#失敗例
lgb_params = {
    'objective':'multiclass',
    'n_estimators':1000,
    'seed': SEED,
    'early_stopping_rounds':100
} 

lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_val, y_val)

model = lgb.train(
            lgb_params,
            lgb_train,
            valid_sets=lgb_eval,
            verbose_eval = 15
            )

confusion matrixを出力してみる。
引数は下記の通りで、estimatorに学習させたモデル、Xに検証部分の特徴量(X_test)、Yに正解データ(Y_test)を入れることで出力される。
y_predは自動的に計算されるので宣言が必要ない。

sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, labels=None, sample_weight=None, normalize=None, display_labels=None, include_values=True, xticks_rotation='horizontal', values_format=None, cmap='viridis', ax=None)

引用元:scikit-learn公式ドキュメント

プロット出力
plot_confusion_matrix(model,X_test,y_test)

すると下記のエラーが返ってくる。
estimatorに指定しているモデルが分類モデルと認識されていない様。

エラー文
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-125-5b518d7bfee9> in <module>()
----> 1 plot_confusion_matrix(model,X_test,y_test)

/usr/local/lib/python3.6/dist-packages/sklearn/metrics/_plot/confusion_matrix.py in plot_confusion_matrix(estimator, X, y_true, labels, sample_weight, normalize, display_labels, include_values, xticks_rotation, values_format, cmap, ax)
    183 
    184     if not is_classifier(estimator):
--> 185         raise ValueError("plot_confusion_matrix only supports classifiers")
    186 
    187     if normalize not in {'true', 'pred', 'all', None}:

ValueError: plot_confusion_matrix only supports classifiers

たしかLightGBMをscikit-learnインターフェースで記述する場合であれば
LGBMClassifierといういかにも分類モデルらしいLightGBMの記述方法があったことを思い出す。
試しにその記述方法で試してみる。

scikit-learnインターフェースで学習させたパターン

scikit-learnインターフェースでの学習
model2 = lgb.LGBMClassifier(objective='multiclass',
                        n_estimators=1000,
                        seed=SEED,
                        early_stopping_rounds=100)
model2.fit(X_train, y_train,
        eval_set=[(X_val, y_val)],
        verbose=50)

model2で再トライしてみる。

プロット再出力
plot_confusion_matrix(model2,X_test,y_test)

出た。
seabornでのplotのようにデフォルトで見やすいプロットが出力された。

タイトルをつける場合は下記のように記載すれば出力される。

タイトルを追加
disp=plot_confusion_matrix(model2,X_test,y_test)
disp.ax_.set_title("Confusion Matrix")
plt.show()

ちなみにコードは割愛するがtitanicの二値分類でも全く同様にして出力できた。
逆に従来のconfusion_matrix.confusion_matrixの場合は二値分類しかできず、多クラス分類だとできなかった。

Conclusion

・scikit-learnの新機能のconfusion matrix plotを用いることでかなりシンプルな記載で描画することができた。
・LightGBMの場合は現状、scikit-learnインターフェースでの学習させた方が良い。
・従来のconfusion_matrix.confusion_matrixとは異なり、二値分類でも多クラス分類でも対応している。