Tensorflow+matplotlibで推論&結果の表示


1. はじめに

TensorflowやPyTorch、Chainerでモデルの評価はできたけど、Deep Learningをあまり知らない人にAccuracyやLossのグラフだけ見せても…って事ありませんか?
また、手っ取り早く、この画像は正解!これは不正解!ってのが見られると嬉しい場面もあるでしょう。
そんな願いを叶えるため、複数並んだ画像に判定結果を載せていく、というのをMatplotlibで実現してみたいと思います。

「―――喜べ少年。君の願いは、ようやく叶う」

また、今回は例としてTensorflowを用いますが、画像の表示部分はフレームワークが何であっても大丈夫です。

2. 実装 & 解説

例として、Tensorflowを使ってMNISTの学習をしたモデルを用意しました。
今回、画像は40枚使用します。

validation.py
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf

# 表示画像枚数の設定
row = 4
col = 10

# データのロード
mnist = tf.keras.datasets.mnist
(_, _), (x_test, y_test) = mnist.load_data()
x_test = np.asarray(x_test[0:row*col])
y_test = np.asarray(y_test[0:row*col])

# モデルのロード
path = 'mnist.h5' # 学習済みモデルのパス
model = tf.keras.models.load_model(path)

# 推論
x_test_flat = x_test.reshape(-1, 784) / 255.0
result = model.predict(x_test_flat)

# 画像の整列
plt.figure(figsize=(5, 5))
image_array = []
for i, image in enumerate(x_test):
    image_array.append(plt.subplot(row, col, i + 1))
    plt.axis('off')
    plt.imshow(image, cmap='gray')
plt.pause(0.1)

# ラベルの配置
for j, image in enumerate(x_test):
    bg_color = 'skyblue' if y_test[j] == np.argmax(result[j]) else 'red'
    image_array[j].text(0, 0, np.argmax(result[j]), color='black', backgroundcolor=bg_color)
    plt.pause(0.1)

# 画像全体の保存
plt.savefig('judge_result.png')

推論について

学習・評価をして認識率のグラフを出力するところまでは色々なサイトに書いてあります。
しかし、学習したモデルで推論して何かする、というところを書いてあるのは意外に少なかったりします。(自分の体感ではありますが…)

x_test_flat = x_test.reshape(-1, 784) / 255.0
result = model.predict(x_test_flat)

tensorflow.keras.modelsにはpredict()というメソッドがあり、推論ではこれを使用します。
このメソッドに、推論したい画像の配列を渡します。モデルの入力は一次元にしているため、reshape(-1, 784)で一次元配列に変換しました。
今回は40枚を一度に処理するので、(40, 784)の配列を渡しますが、1枚だけ処理する時も(1, 784)として渡す必要があります。

Chainerではresult = model.predictor(test_x).data[0]と記述することで推論が可能(なはず)です。

画像の表示について

matplotlibでは、オブジェクト指向的に記述することが出来ます。

plt.figure(figsize=(5, 5))
image_array = []
for i, image in enumerate(x_test):
    image_array.append(plt.subplot(row, col, i + 1))
    plt.axis('off')
    plt.imshow(image, cmap='gray')
plt.pause(0.05)

1つ目のfor文で、plt.subplotを利用して画像を並べていきます。
plt.subplotfigureに対する子要素です。引数に(縦、横、何番目か)を渡します。
何番目かを表す数字は、1から数えます。(0からでないので注意)
まずは全画像を表示し、あとから子要素にラベルを追加してくため、操作できるようにappend()しておきましょう。
(ただ表示するだけなら配列に入れておく必要はありませんが、後からラベルを追加したいので、今回はこのようにしておきます。)
また、今回はグラフではなく画像なので、plt.axis('off')で座標の表示を消しています。
一度に並べる画像の用意が終わったら、plt.pause()で画像を表示します。
plt.show()にしてしまうと、そこで処理が止まってしまうので、plt.pause()を使っています。

for j, image in enumerate(x_test):
    bg_color = 'skyblue' if y_test[j] == np.argmax(result[j]) else 'red'
    image_array[j].text(0, 0, np.argmax(result[j]), color='black', backgroundcolor=bg_color)
    plt.pause(0.05)

2つ目のfor文では、image_arrayの要素に一つずつラベルを追加していきます。
image_array[j].text(0, 0, np.argmax(result[j])で推論結果を画像に追加します。
推論結果np.argmax(result[j])と、正解ラベルy_test[j]が一致したら背景を青に、不一致なら赤にしてみました。
plt.pause()で画面にラベルを表示します。引数は画像を表示する時間であり、この数値を変えることで表示の更新速度が変わります。あくまで"表示"の更新速度であり、モデルの処理速度ではないため注意しましょう。

3. まとめ

今回作成したモノですが、学生実験のプレゼンで使いたかったけど、当時の自分は実装できず…
今考えたら簡単でも、学習段階によっては難しいことってありますよね。
Deep Learningを始めて間もない人や、プログラミングの講義を受けてきたけど嫌いであまり理解できていないという人に届いたら幸いです。