ResNet18ネットワークのTensorFlow 2(Python)による実装

参考資料:北京大学 ソフトウェア・マイクロエレクトロニクス学院 曹健先生『人工知能実践:TensorFlow 2.0ノート』
実行環境:Python 3.7、TensorFlow 2.1.0、NumPy 1.17.4、Matplotlib 3.2.1
# resnet18
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model


# Load CIFAR-10 and scale pixel values from [0, 255] to [0, 1].
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Cast to float32 before scaling: `uint8 / 255.0` would produce float64 arrays
# (twice the memory), while the Keras layers compute in float32 anyway and
# would otherwise cast every float64 batch down (with a warning in TF2).
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

class ResnetBlock(Model):
    """One basic residual unit: two 3x3 conv/BN stages plus a shortcut.

    The block returns ``relu(F(x) + shortcut)``, where ``F`` is
    conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The shortcut is the input itself,
    or, when ``residual_path`` is True, a 1x1 convolution (using the same
    ``strides`` as the first 3x3 conv) followed by BN, so that it can be added
    to ``F(x)`` when the block changes the spatial size or channel count.

    Args:
        filters: number of output channels of every convolution in the block.
        strides: stride of the first 3x3 conv and of the 1x1 shortcut conv.
        residual_path: if True, project the shortcut with 1x1 conv + BN
            instead of passing the input through unchanged.
    """

    def __init__(self, filters, strides=1, residual_path=False):
        super().__init__()
        self.filters = filters
        self.strides = strides
        self.residual_path = residual_path

        # Options shared by every convolution here: 'same' padding, and no bias
        # because each conv is immediately followed by BatchNormalization.
        conv_opts = dict(padding='same', use_bias=False)

        # Main path, first stage: (possibly strided) conv -> BN -> ReLU.
        self.c1 = Conv2D(filters, (3, 3), strides=strides, **conv_opts)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        # Main path, second stage: stride-1 conv -> BN; the ReLU comes after the sum.
        self.c2 = Conv2D(filters, (3, 3), strides=1, **conv_opts)
        self.b2 = BatchNormalization()

        # Projection shortcut (1x1 conv + BN), built only when requested so
        # that x can be brought to the same shape as F(x) before the addition.
        if residual_path:
            self.down_c1 = Conv2D(filters, (1, 1), strides=strides, **conv_opts)
            self.down_b1 = BatchNormalization()

        self.a2 = Activation('relu')

    def call(self, inputs):
        # By default the shortcut is the identity: x itself.
        shortcut = inputs

        # F(x): conv -> BN -> ReLU -> conv -> BN.
        hidden = self.a1(self.b1(self.c1(inputs)))
        f_x = self.b2(self.c2(hidden))

        # Replace the identity with the projected input W*x when configured.
        if self.residual_path:
            shortcut = self.down_b1(self.down_c1(inputs))

        # Output is relu(F(x) + x) or relu(F(x) + W*x).
        return self.a2(f_x + shortcut)

class ResNet18(Model):
    """Residual network built from stacked ResnetBlock units with a softmax head.

    Architecture: conv3x3 stem (``initial_filters`` channels) -> BN -> ReLU ->
    residual stages -> global average pooling -> Dense softmax.

    Args:
        block_list: number of ResnetBlock units in each stage; ``[2, 2, 2, 2]``
            gives ResNet18. Every stage after the first starts with a stride-2
            block that projects its shortcut, and each successive stage uses
            twice as many filters as the previous one.
        initial_filters: channel count of the stem conv and of the first stage.
        num_classes: size of the final softmax layer (default 10, which keeps
            the original behavior).
    """

    def __init__(self, block_list, initial_filters=64, num_classes=10):
        super().__init__()
        self.num_block = len(block_list)  # number of residual stages
        self.block_list = block_list
        # Running channel count; doubled after each stage is built, so after
        # __init__ it equals initial_filters * 2 ** len(block_list).
        self.out_filters = initial_filters

        # Stem: 3x3 stride-1 conv -> BN -> ReLU.
        self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        # Residual stages, all appended to one Sequential container.
        self.blocks = tf.keras.Sequential()
        for block_id, num_layers in enumerate(block_list):
            for layer_id in range(num_layers):
                if block_id != 0 and layer_id == 0:
                    # First unit of every stage except the first: halve the
                    # spatial size and project the shortcut to the new shape.
                    block = ResnetBlock(self.out_filters, strides=2, residual_path=True)
                else:
                    block = ResnetBlock(self.out_filters, residual_path=False)
                self.blocks.add(block)
            self.out_filters *= 2  # the next stage uses twice as many filters

        self.p1 = tf.keras.layers.GlobalAveragePooling2D()
        # Softmax classifier with an L2 weight penalty (library default factor).
        self.f1 = tf.keras.layers.Dense(num_classes, activation='softmax',
                                        kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, inputs):
        # Stem.
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)
        # Residual stages.
        x = self.blocks(x)
        # Pool to one vector per example, then classify.
        x = self.p1(x)
        y = self.f1(x)
        return y

model = ResNet18([2, 2, 2, 2])

# The final Dense layer already applies softmax, so the loss receives
# probabilities rather than logits. The metric name fixes the keys that are
# read back from `history.history` when plotting below.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume from a previous run when saved weights exist. Weights saved in the
# TensorFlow checkpoint format produce a `<path>.index` file, hence the probe.
checkpoint_save_path = './checkpoint/ResNet18.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

# Save weights only (TensorFlow checkpoint format, matching the `.index`
# probe), and only when the monitored validation loss improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=128, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

# Dump each trainable variable's name, shape and value for inspection.
# NOTE: str() of a large ndarray is summarized with "..." under numpy's
# default print options, so this is a readable overview, not a full export.
with open('./weights_ResNet18.txt', 'w', encoding='utf-8') as weights_file:
    for v in model.trainable_variables:
        weights_file.write(str(v.name) + '\n')
        weights_file.write(str(v.shape) + '\n')
        weights_file.write(str(v.numpy()) + '\n')

###############################################    show   ###############################################

# Plot training / validation accuracy and loss curves side by side.
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()