多クラス分類の結果の混同行列の集計、評価指標の計算


多クラス分類で、未知データを入力とした際、各クラスへの分類確率の予測値が算出される。

テストデータを用いた評価の際、混同行列を書く必要があるので、この記事では、その書き方を例を示す。

データの用意

3つのクラスに対して、5つのデータを分類器に入力した結果を results、正解を y_true とする。

results = [
    [0.1, 0.2, 0.7],
    [0.5, 0.3, 0.3],
    [0.2, 0.3, 0.2],
    [0.2, 0.7, 0.2],
    [0.2, 0.5, 0.2]
]
y_true = [
    [0, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 0, 1],
    [0, 1, 0]
]

混同行列を出す

利用するものとしては、 sklearn.metrics.confusion_matrix を使う。

sklearn.metrics.confusion_matrix

この入力に合う形にデータを整形する。

import numpy as np

results2 = [np.argmax(i) for i in results]
# [2, 0, 1, 1, 1]

y_true2 = [np.argmax(i) for i in y_true]
# [2, 0, 2, 2, 1]

よって以下のように集計できる。

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true2, results2)

# array([[1, 0, 0],
#        [0, 1, 0],
#        [0, 2, 1]])

seabornで可視化する

seabornを使うことで、sklearn.metrics.confusion_matrixの結果を美しく可視化できるのでやってみよう。

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

labels = sorted(list(set(y_true2)))

df = pd.DataFrame(confusion_matrix(y_true2, results2), index=labels, columns=labels)

plt.figure()
sns.heatmap(df, annot=True) # annot = Falseにすると図中の各セルの数値表示がなくなる。
plt.show()

↑ 図がずれているのはこちらのバグ?のようだ。

おまけ

results を, y_true と同様に、binaryで表示すると、 かつての記事でも書いた sklearn.metrics.classification_report が利用できる。これにより、各種評価指標を一度に計算してくれる。

まずは変換。

b_results = []
for i in results:
    max_index = np.argmax(i)
    tmp = [0] * max(map(len, result))
    tmp[max_index] = 1
    b_results.append(tmp)
# [[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0]]

sklearn.metrics.classification_report を利用する。

from sklearn.metrics import classification_report

print(classification_report(np.array(y_true), np.array(b_results)))

出力結果。

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         1
           1       0.33      1.00      0.50         1
           2       1.00      0.33      0.50         3

   micro avg       0.60      0.60      0.60         5
   macro avg       0.78      0.78      0.67         5
weighted avg       0.87      0.60      0.60         5
 samples avg       0.60      0.60      0.60         5