# マルチカテゴリタスク混同マトリクス(pythonコード実装) — Multi-class confusion matrix (Python implementation)


import csv
import os

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, result_path, title='Confusion Matrix', classes=None):
    """Draw a confusion matrix as an annotated heat map and save it as a PNG.

    The image is written to ``result_path`` with its extension replaced by
    ``.png`` (e.g. ``DL_train.csv`` -> ``DL_train.png``). It is then shown
    with ``plt.show()``, and finally the figure is closed.

    Args:
        cm: Square matrix. Cell ``cm[row][col]`` is drawn at x=col, y=row.
            The y axis is labelled ground truth and the x axis prediction.
        result_path: Path of the results file the matrix came from. It is
            only used to derive the output image name.
        title: Plot title.
        classes: Tick labels, one per row/column of ``cm``. If omitted, the
            module-level ``classes`` list is used when it exists. Otherwise
            the indices ``'0' .. 'n-1'`` are used.

    Raises:
        ValueError: If ``cm`` is not a non-empty square 2-D matrix, or if the
            number of labels differs from the size of ``cm``.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1] or cm.shape[0] == 0:
        raise ValueError(f'cm must be a non-empty square 2-D matrix, got shape {cm.shape}')
    n = cm.shape[0]

    if classes is None:
        # Backward compatibility: earlier versions read the module-level
        # ``classes`` global implicitly.
        classes = globals().get('classes')
    if classes is None:
        classes = [str(i) for i in range(n)]
    if len(classes) != n:
        raise ValueError(f'got {len(classes)} class labels for a {n}x{n} confusion matrix')

    plt.figure(figsize=(4, 4), dpi=300)
    # Kept from the original: changes NumPy's global print precision for the
    # rest of the process. It does not affect the cell text below, which is
    # formatted explicitly.
    np.set_printoptions(precision=2)

    # Write each cell's value at its centre. Cells above half the maximum get
    # white text so the number stays readable on the darker colours.
    threshold = cm.max() / 2
    for row in range(n):
        for col in range(n):
            value = cm[row, col]
            plt.text(col, row, "%0.2f" % (value,),
                     color="white" if value > threshold else "black",
                     fontsize=10, va='center', ha='center')

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_positions = np.arange(n)
    plt.xticks(tick_positions, classes)
    plt.yticks(tick_positions, classes)
    plt.ylabel('Ground truth')
    plt.xlabel('Predict')

    # Minor ticks halfway between cells carry the grid lines that separate
    # the cells. The tick marks themselves are hidden.
    ax = plt.gca()
    minor_positions = tick_positions + 0.5
    ax.set_xticks(minor_positions, minor=True)
    ax.set_yticks(minor_positions, minor=True)
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', color="gray", linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.05)

    # Replace the input's extension (whatever its length) with .png.
    plt.savefig(os.path.splitext(result_path)[0] + '.png', format='png')
    plt.show()
    # Release the figure. This function is called once per results file, and
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close()

classes = ['M0', 'M1', 'M2']

# Synthetic demo labels (not referenced elsewhere in this script): 50 class
# indices drawn uniformly from 0 .. len(classes) - 1 serve as ground truth.
random_numbers = np.random.randint(len(classes), size=50)
y_true = random_numbers.copy()
# Re-draw the first 10 entries to mimic prediction errors. A re-drawn value
# may still coincide with the true label by chance.
random_numbers[:10] = np.random.randint(len(classes), size=10)
y_pred = random_numbers  # same array object as random_numbers (no copy)


result_paths = ['DL_train.csv', 'DLC_train.csv', 'DL_test.csv', 'DLC_test.csv']

# Predictions are continuous scores. Each score is mapped to a class index by
# cutting at the midpoints 0.5, 1.5, ... between consecutive class indices.
class_indices = np.arange(len(classes))
thresholds = class_indices[:-1] + 0.5

for result_path in result_paths:
    # Assumed layout (inferred from the parsing below; TODO confirm): a header
    # row, then rows of "id, ground-truth label, predicted score".
    with open(result_path, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        # Skip blank lines such as a trailing newline. The old fixed [1:-1]
        # slice dropped a real row whenever the file had no trailing newline.
        result_list = [row for row in reader if row]

    if not result_list:
        print(result_path, 'contains no data rows; skipped')
        continue

    id_list = [int(result[0]) for result in result_list]  # parsed but otherwise unused
    y = np.array([float(result[1]) for result in result_list])
    p = np.array([float(result[2]) for result in result_list])

    # np.digitize gives p < 0.5 -> 0, 0.5 <= p < 1.5 -> 1, p >= 1.5 -> 2.
    # The previous masks (p < 0.5, 0.5 < p < 1.5, p > 1.5) left scores of
    # exactly 0.5 or 1.5 unassigned. Those became extra "classes" in cm.
    p = np.digitize(p, thresholds)

    # Fixing the label set keeps cm at len(classes) x len(classes) even when
    # a class never occurs in this file. Ground-truth values are assumed to
    # be integral (e.g. "1" or "1.0").
    cm = confusion_matrix(y.astype(int), p, labels=class_indices)
    plot_confusion_matrix(cm, result_path, title='Confusion matrix',)
    # Overall accuracy: diagonal (correct) counts over all samples.
    print(result_path, np.trace(cm) / cm.sum())