Kerasで始める機械学習(MNIST編)
はじめに
機械学習におけるHelloWorld的位置にあるMNIST(手書き文字)の認識をKerasを用いて行います。
想定読者は機械学習、DeepLearningについて概念を学んだのでKerasで簡単な実装を行ないたい方です。
実行環境
macOS Mojave
Python 3.6.8
Keras 2.2.4
MNISTデータセットのLOAD
keasと必要なライブラリをImportし、MNISTデータセットをロードします。
MNISTデータはkerasに実装されている関数load_dataを利用し取得できます。
load_dataの詳細については公式サイトを参照ください。
import keras
import numpy
from keras.datasets import mnist
# MNISTデータのロード
train, test = mnist.load_data()
X_train, y_train = train
X_test, y_test = test
# サイズを変更(画像数、解像度(縦、横)、白黒)
X_train = X_train.reshape(60000,28,28,1)
X_test = X_test.reshape(10000,28,28,1)
MNISTデータを確認する
ロードしたMNISTデータをmatplotlibを使って確認してみます。
参考:matplotlib.pyplot.imshow
import matplotlib.pyplot as plt
def show_img(img, figsize=(2,2)):
fig = plt.figure(figsize=figsize,dpi=100)
plt.imshow(img, cmap = 'gray', interpolation = 'bicubic')
plt.xticks([]), plt.yticks([]) # to hide tick values on X and Y axis
plt.show()
# X_trainの一番目の画像を表示
show_img(X_train[0])
出力結果
数字の”5”と思われる画像が確認できます。
One-Hot encoding
教師ラベルをOne-Hot Encodingします。
KerasではOne-hot encodingする関数としてto_categoricalが用意されています。
One-hot とは1つだけHigh(1)であり、他はLow(0)であるようなビット列のことです。(wikipedia)
MNISTは0~9の教師ラベルを持つ手書き画像なので例えば教師ラベルの[5]は[0,0,0,0,0,1,0,0,0,0]のように6番目のみ1が立つようにEncodeされます。
from keras.utils.np_utils import to_categorical
y_trn_cgl = to_categorical(y_train)
y_test_cgl = to_categorical(y_test)
print(f"train_labels.shape:{y_train.shape}, y_trn.shape:{y_test.shape}")
print(f"y_trn_cgl:{y_trn_cgl}")
print(f"y_test_cgl:{y_test_cgl}")
CNNモデルの構築
モデルの構造は (2D Conv -> Relu -> MaxPooling)*2 -> 全結合層 -> Softmax
となるシンプルなニューラルネットを構築します。
各レイヤの仕様について公式サイトのリンクを貼っておきます。
・Conv2D:2次元畳み込み層
・Activation:活性化関数(今回はReluを使用)
・MaxPooling2D:Maxプーリング層
・Dense:全結合層
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Flatten
def create_cnn_model(input_size):
# 入力データサイズを指定
X_input = Input((input_size[0], input_size[1], 1))
# 2次元畳み込み層。
X = Conv2D(filters=64, kernel_size=(5,5), padding="valid")(X_input)
# 活性化関数はReluを使用
X = Activation('relu')(X)
# MaxPooling層
X = MaxPooling2D(pool_size = (2,2))(X)
X = Conv2D(filters=128, kernel_size=(5,5), padding="valid")(X)
X = Activation('relu')(X)
X = MaxPooling2D(pool_size = (2,2))(X)
X = Flatten()(X)
# 全結合層。MNISTでは10クラスへ分類するため出力10で指定
X = Dense(10, activation="softmax")(X)
model = Model(inputs=X_input, outputs=X)
return model
CNNモデルの学習
MNISTデータの準備とCNNモデルの準備ができましたので、
モデルの学習を行なっていきます。
from keras import optimizers
# セッションの初期化
K.clear_session()
# 入力画像サイズを指定
input_size = (28, 28)
# モデルを作成
cnn_model = create_cnn_model(input_size)
# 最適化アルゴリズムを指定(lr=学習率)
opt = optimizers.Adagrad(lr=1e-4)
# 多クラス分類用の損失関数の指定
loss = "categorical_crossentropy"
metrics = ['accuracy']
# 設定を反映
cnn_model.compile(optimizer=opt, loss=loss, metrics=metrics)
#学習を実行
# 学習経過をhistory_cnnに格納してます。
history_cnn = cnn_model.fit(X_train[0:6000], y_trn_cgl[0:6000], epochs=10)
テストデータの予測
学習したモデルでテストデータを予測します。
predict = cnn_model.predict(X_test)
eval_loss, eval_acc = cnn_model.evaluate(X_test, y_test_cgl)
print("eval_loss:{0}\neval_acc{1}".format(eval_loss, eval_acc))
def show_history(history):
fig, ax = plt.subplots(1, 2, figsize=(15,5))
ax[0].set_title('loss')
ax[0].plot(history.epoch, history.history["loss"], label="Train loss")
#ax[0].plot(history.epoch, history.history["val_loss"], label="Validation loss")
ax[1].set_title('categorical_accuracy')
ax[1].plot(history.epoch, history.history["acc"], label="Train accuracy")
#ax[1].plot(history.epoch, history.history["val_acc"], label="Validation accuracy")
ax[0].legend()
ax[1].legend()
plt.show()
show_history(history_cnn)
実行結果
左の図がLossを、右の図が正答率(categorical accuracy)を示しています。
横軸はEpoch数、つまり学習回数です。
学習回数を重ねるにつれLossが下がり学習が進み、それに伴い正答率が向上していることがわかります。
MNISTは比較的簡単な画像分類のため10回の学習で85%近い正答率が出せることがわかります。
特に工夫はしていないCNNモデルでも学習回数を増やせば90%近い正答率は出せると思います。
参考リンク
Author And Source
この問題について(Kerasで始める機械学習(MNIST編)), 我々は、より多くの情報をここで見つけました https://qiita.com/note-tech/items/8e799e4bad9af36b2c4f著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .