TensorFlow 2.0 の Keras モジュールで MNIST データセットの識別を実現する


tf.keras でネットワークを構築する定型手順（いわゆる「八股」）は以下の通りです:
基本の 6 ステップ: 0 → 1 → 3 → 4 → 6 → 7
残りの 2（データ拡張）と 5（チェックポイントからの学習再開）は、必要に応じて追加する拡張操作です。
目次
0. import ...  # 関連ライブラリのインポート
1. train, test  # 訓練データとテストデータの準備
2. ImageDataGenerator  # データ拡張（data augmentation）
3. model = tf.keras.models.Sequential  # ネットワーク構造の構築
4. model.compile  # 学習方法（オプティマイザ・損失関数・評価指標）の設定
5. チェックポイントからの学習再開（断点続訓）
6. model.fit  # 学習の実行（データの投入）
7. model.summary  # ネットワーク構造とパラメータ数の表示
8. サンプルコード
0. import ...  # 関連ライブラリのインポート
1. train, test  # 訓練データとテストデータの準備
2. ImageDataGenerator  # データ拡張（data augmentation）
# Data augmentation: random transforms applied to training images on the fly.
image_gen_train = ImageDataGenerator(
    rescale=1.0 / 1.0,  # pixel scale factor; 1.0 is a no-op (use 1.0/255 for raw 0-255 pixels)
    rotation_range=45,  # random rotation of up to 45 degrees
    width_shift_range=0.15,  # random horizontal shift, as a fraction of image width
    height_shift_range=0.15,  # random vertical shift, as a fraction of image height
    horizontal_flip=True,  # random left-right mirroring
    zoom_range=0.5,  # random zoom factor drawn from [0.5, 1.5]
)
# Computes data-dependent statistics; Keras only needs this for featurewise
# centering/normalization or ZCA whitening. x_train must be rank 4
# (samples, height, width, channels).
image_gen_train.fit(x_train)

3. model = tf.keras.models.Sequential  # ネットワーク構造の構築
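model = tf.keras.models.Sequential([
    # Catalogue of common layers -- pick the ones your network needs;
    # this list as written is not a working architecture.
    tf.keras.layers.Flatten(),  # flatten layer
    tf.keras.layers.Dense(<units>, activation='<activation>', kernel_regularizer=<regularizer>),  # fully connected layer
    # activation options: 'relu', 'softmax', 'sigmoid', 'tanh'
    # kernel_regularizer options: tf.keras.regularizers.l1(), tf.keras.regularizers.l2()
    tf.keras.layers.Conv2D(filters=<number of kernels>, kernel_size=<kernel size>, strides=<stride>,
                           padding=<'valid' or 'same'>),  # convolutional layer
    tf.keras.layers.MaxPool2D(2, 2),  # max-pooling layer
    tf.keras.layers.LSTM(<units>)  # LSTM layer
])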

4. model.compile  # 学習方法（オプティマイザ・損失関数・評価指標）の設定
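model.compile(optimizer=<optimizer>, loss=<loss function>, metrics=['<metric>'])
# optimizer options: 'sgd' or tf.keras.optimizers.SGD(learning_rate=<learning rate>, momentum=<momentum>)
#   'adagrad' or tf.keras.optimizers.Adagrad(learning_rate=<learning rate>)
#   'adadelta' or tf.keras.optimizers.Adadelta(learning_rate=<learning rate>)
#   'adam' or tf.keras.optimizers.Adam(learning_rate=<learning rate>, beta_1=0.9, beta_2=0.999)
# loss options: 'mse' or tf.keras.losses.MeanSquaredError()
#   'sparse_categorical_crossentropy' or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
# metrics options: 'accuracy': y_ and y are both plain labels, e.g. y_=[1], y=[1]
#   'categorical_accuracy': y_ and y are both one-hot / probability vectors, e.g. y_=[0,1,0], y=[0.256,0.695,0.048]
#   'sparse_categorical_accuracy': y_ is an integer label, y is a probability vector, e.g. y_=[1], y=[0.256,0.695,0.048]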

5. チェックポイントからの学習再開（断点続訓）
# Resume training from a checkpoint: load previously saved weights if they exist.
# An existing '<path>.index' file marks weights saved in TensorFlow checkpoint format.
checkpoint_save_path = './checkpoint/fashion_mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('----------------------------load the model--------------------------------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored value (val_loss by default) improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True)

# Train on augmented batches. An ImageDataGenerator instance is not callable;
# .flow() returns an iterator that yields (augmented images, labels) batches.
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

6. model.fit  # 学習の実行（データの投入）
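model.fit(<training inputs>, <training labels>, batch_size=..., epochs=...,
          validation_data=(<test inputs>, <test labels>),
          validation_split=<fraction of the training data to hold out for validation>,
          validation_freq=<run validation every N epochs>)
# Pass either validation_data or validation_split, not both
# (if both are given, validation_data takes precedence).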

7. model.summary  # ネットワーク構造とパラメータ数の表示
model.summary()
8. サンプルコード
# 0.    
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
from PIL import Image
import numpy as np
import os
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# 1. Load MNIST (train and test splits) and scale pixel values from 0-255 to 0-1.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Data augmentation. ImageDataGenerator works on rank-4 arrays
# (samples, height, width, channels), so give the training images a channel axis.
x_train = x_train.reshape((len(x_train), 28, 28, 1))
image_gen_train = ImageDataGenerator(
    rescale=1.0 / 1.0,  # scale factor 1.0 is a no-op: pixels were already divided by 255 above
    rotation_range=45,  # random rotation of up to 45 degrees
    width_shift_range=0.15,  # random horizontal shift, as a fraction of image width
    height_shift_range=0.15,  # random vertical shift, as a fraction of image height
    horizontal_flip=True,  # random mirroring; NOTE(review): mirrored digits may hurt MNIST accuracy
    zoom_range=0.5,  # random zoom factor drawn from [0.5, 1.5]
)
image_gen_train.fit(x_train)

# 2. Define the network as a subclassed Keras Model.
class MnistModel(Model):
    """Fully connected classifier: Flatten -> Dense(128, relu) -> Dense(10, softmax)."""

    def __init__(self):
        super().__init__()
        # Layers are created once here and reused on every forward pass.
        # Keep these attribute names: TF-format weight checkpoints track layers
        # by attribute name, so renaming them would break loading saved weights.
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Forward pass: flatten each sample, then apply the two dense layers."""
        return self.d2(self.d1(self.flatten(x)))

model = MnistModel()

# 3. Configure training: optimizer, loss and the metric to report.
# The last layer already applies softmax, so the loss receives probabilities
# rather than logits (from_logits=False).
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from a previous run: an existing '<path>.index' file marks weights
# saved in TensorFlow checkpoint format.
checkpoint_save_path = './checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('----------------------------load the model--------------------------------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored value (val_loss by default) improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True)

# 4. Train. An ImageDataGenerator instance is not callable; .flow() yields
# batches of (augmented images, labels) for model.fit to consume.
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
# 5. Print the network structure and the parameter count of each layer.
model.summary()

# ---------------- Plot training/validation accuracy and loss curves ----------------
curves = history.history
loss, val_loss = curves['loss'], curves['val_loss']
acc, val_acc = curves['sparse_categorical_accuracy'], curves['val_sparse_categorical_accuracy']

# One row of two side-by-side panels: accuracy on the left, loss on the right.
panels = (
    (acc, val_acc, 'training accuracy', 'validation accuracy',
     'training and validation accuracy curve'),
    (loss, val_loss, 'training loss', 'validation loss',
     'training and validation loss curve'),
)
for panel_index, (train_curve, val_curve, train_label, val_label, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.title(panel_title)
    plt.legend()
plt.show()
# ---------------- Predict digits in user-supplied images ----------------
def img_trans(img_path, threshold=200):
    """Load an image file and convert it to a 28x28 MNIST-style model input.

    The image is resized to 28x28, converted to grayscale and binarized.
    Pixels darker than ``threshold`` (pen strokes on light paper) become 1.0,
    and everything else becomes 0.0. This gives a light digit on a dark
    background, the same polarity as the MNIST training images.

    Args:
        img_path: path of the image file to read.
        threshold: grayscale cut-off in 0-255; pixels strictly below it count as ink.

    Returns:
        A float64 numpy array of shape (28, 28) whose values are 0.0 or 1.0.
    """
    # `with` closes the underlying file handle once the pixels have been read.
    with Image.open(img_path) as img:
        # LANCZOS is the filter formerly exposed as Image.ANTIALIAS (removed in Pillow 10).
        img = img.resize((28, 28), Image.LANCZOS)
        gray = np.array(img.convert('L'))
    # A single vectorized threshold: ink -> 1.0, background -> 0.0. No extra
    # `255 - x` inversion is needed, because the threshold itself already
    # produces light-on-dark. Inverting again would reverse the polarity
    # relative to the training data.
    return np.where(gray < threshold, 1.0, 0.0)

# Ask for image paths interactively and print the model's prediction for each.
num = int(input('the number of images: '))
for _ in range(num):
    img_path = input('the path of image:')
    img_arr = img_trans(img_path)
    # Add a leading batch axis: model.predict expects a batch of images.
    x_predict = img_arr[tf.newaxis, ...]
    result = model.predict(x_predict)  # per-class probabilities from the softmax layer
    print(result)
    pred = tf.argmax(result, axis=1)  # index of the most probable class
    # Previously a string literal was broken across a line break (SyntaxError).
    print('success predict')
    print(pred)

拡張:
Fashion-MNIST データセットを使う場合、違いはデータの読み込み部分だけです。上記サンプルコードの「1.」の読み込み 3 行を次のように置き換えます。
# Fashion-MNIST variant of the loading step: only the dataset handle changes.
mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from 0-255 to 0-1.
x_train = x_train / 255.0
x_test = x_test / 255.0