Predicting handwritten digits with a model trained on MNIST


Using a model trained on the MNIST dataset, I try predicting my own handwritten digits.

What is MNIST?

MNIST is a dataset of handwritten digit images. It consists of 60,000 training images and 10,000 test images.
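If you want to confirm those numbers yourself, Keras can download the dataset directly. A minimal check (the dataset is fetched on first use):

from keras.datasets import mnist

# Each image is a 28x28 grayscale array; labels are the digits 0-9
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape)  # (60000, 28, 28) - training images
print(x_test.shape)   # (10000, 28, 28) - test images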

Creating the trained model

I used the following sample training code.
(https://github.com/keras-team/keras/blob/master/examples/mnist_mlp.py)
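For reference, here is a rough sketch of the network the sample defines, reconstructed from the layer summary in the output below. The dropout rate and optimizer are my reading of mnist_mlp.py, so treat them as assumptions; the layer sizes and parameter counts match the summary.

from keras.models import Sequential
from keras.layers import Dense, Dropout

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))  # dense_1: (784+1)*512 = 401,920 params
model.add(Dropout(0.2))                                       # dropout rate assumed from the sample
model.add(Dense(512, activation='relu'))                      # dense_2: (512+1)*512 = 262,656 params
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))                    # dense_3: (512+1)*10 = 5,130 params

model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])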

The output was as follows.

Output
60000 train samples
10000 test samples
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 512)               262656    
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                5130      
=================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 2s 34us/step - loss: 0.2460 - acc: 0.9250 - val_loss: 0.1143 - val_acc: 0.9665
Epoch 2/20
60000/60000 [==============================] - 1s 15us/step - loss: 0.1036 - acc: 0.9690 - val_loss: 0.0774 - val_acc: 0.9748
Epoch 3/20
60000/60000 [==============================] - 1s 15us/step - loss: 0.0748 - acc: 0.9778 - val_loss: 0.0769 - val_acc: 0.9780
Epoch 4/20
60000/60000 [==============================] - 1s 16us/step - loss: 0.0603 - acc: 0.9815 - val_loss: 0.0782 - val_acc: 0.9794
Epoch 5/20
60000/60000 [==============================] - 1s 14us/step - loss: 0.0506 - acc: 0.9844 - val_loss: 0.0804 - val_acc: 0.9802
Epoch 6/20
60000/60000 [==============================] - 1s 14us/step - loss: 0.0442 - acc: 0.9873 - val_loss: 0.0858 - val_acc: 0.9810
Epoch 7/20
60000/60000 [==============================] - 1s 15us/step - loss: 0.0385 - acc: 0.9885 - val_loss: 0.0730 - val_acc: 0.9842
Epoch 8/20
60000/60000 [==============================] - 1s 13us/step - loss: 0.0342 - acc: 0.9899 - val_loss: 0.0877 - val_acc: 0.9807
Epoch 9/20
60000/60000 [==============================] - 1s 13us/step - loss: 0.0324 - acc: 0.9905 - val_loss: 0.0966 - val_acc: 0.9793
Epoch 10/20
60000/60000 [==============================] - 1s 13us/step - loss: 0.0291 - acc: 0.9920 - val_loss: 0.0832 - val_acc: 0.9843
Epoch 11/20
60000/60000 [==============================] - 1s 14us/step - loss: 0.0276 - acc: 0.9923 - val_loss: 0.0898 - val_acc: 0.9834
Epoch 12/20
60000/60000 [==============================] - 1s 14us/step - loss: 0.0259 - acc: 0.9926 - val_loss: 0.0923 - val_acc: 0.9828
Epoch 13/20
60000/60000 [==============================] - 1s 17us/step - loss: 0.0262 - acc: 0.9929 - val_loss: 0.1002 - val_acc: 0.9830
Epoch 14/20
60000/60000 [==============================] - 1s 14us/step - loss: 0.0214 - acc: 0.9941 - val_loss: 0.1082 - val_acc: 0.9825
Epoch 15/20
60000/60000 [==============================] - 1s 14us/step - loss: 0.0209 - acc: 0.9943 - val_loss: 0.0967 - val_acc: 0.9837
Epoch 16/20
60000/60000 [==============================] - 1s 14us/step - loss: 0.0209 - acc: 0.9945 - val_loss: 0.1151 - val_acc: 0.9825
Epoch 17/20
60000/60000 [==============================] - 1s 16us/step - loss: 0.0183 - acc: 0.9951 - val_loss: 0.1004 - val_acc: 0.9845
Epoch 18/20
60000/60000 [==============================] - 1s 16us/step - loss: 0.0178 - acc: 0.9954 - val_loss: 0.1169 - val_acc: 0.9823
Epoch 19/20
60000/60000 [==============================] - 1s 22us/step - loss: 0.0199 - acc: 0.9954 - val_loss: 0.1193 - val_acc: 0.9835
Epoch 20/20
60000/60000 [==============================] - 1s 13us/step - loss: 0.0180 - acc: 0.9952 - val_loss: 0.1114 - val_acc: 0.9836
Test loss: 0.11135620580659124
Test accuracy: 0.9836

Having confirmed that it runs, I add the following line to the sample code to save the trained model.

model.save('model.h5')
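This call saves the architecture, weights, and optimizer state to a single HDF5 file. A minimal sanity check, assuming it is placed after model.fit() in the sample script (so x_test and y_test are already defined there):

from keras.models import load_model

model.save('model.h5')                 # write the trained model to disk

# Reload and re-evaluate on the test set to confirm the file is usable
restored = load_model('model.h5')
score = restored.evaluate(x_test, y_test, verbose=0)
print('Restored test accuracy:', score[1])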

Data used

The images were created with a smartphone app and then edited in Windows Paint.
There is one image for each digit from 1 to 9, and each file is named after its digit (e.g. 1.png for the digit 1).

Predicting the images

I used the following code to make the predictions.

mnist_predict.py
import glob

import numpy as np
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img

model = load_model('model.h5')
files = glob.glob('./num/*.png')

for file in files:
    # Load as 28x28 grayscale, matching the MNIST input format
    img = load_img(file, grayscale=True, target_size=(28, 28))
    # Flatten to a (1, 784) vector and scale to [0, 1], as in the training script
    array = img_to_array(img).reshape(1, 784).astype('float32') / 255.0

    features = model.predict(array)
    # file is './num/N.png', so file[6] is the digit in the file name
    print("Correct:%s Predict:%d" % (file[6], features.argmax()))

The output was as follows.

Output
Correct:9 Predict:3
Correct:1 Predict:1
Correct:5 Predict:5
Correct:4 Predict:4
Correct:3 Predict:3
Correct:8 Predict:8
Correct:7 Predict:1
Correct:2 Predict:2
Correct:6 Predict:6

The model got 7 and 9 wrong.
I can understand the 7, but the 9... hmm.
Maybe a CNN would do better? I'd like to keep experimenting.
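To dig into why the 9 was read as a 3, it can help to look at the full softmax output rather than just the argmax. A minimal sketch, assuming the same model.h5 and the ./num/9.png file used above:

import numpy as np
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img

model = load_model('model.h5')

# Preprocess exactly as in mnist_predict.py
img = load_img('./num/9.png', grayscale=True, target_size=(28, 28))
x = img_to_array(img).reshape(1, 784).astype('float32') / 255.0

probs = model.predict(x)[0]               # one probability per digit class
for cls in np.argsort(probs)[::-1][:3]:   # top 3 candidates
    print('digit %d: %.3f' % (cls, probs[cls]))

If the correct class shows up as a close second, the mistake is more a matter of confidence than a total misread.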