【Fine-Tuning】Base-modelにInput-modelを追加して学習する新しいスタイルの提案♬


話はどこかにありそうですが、調べた限りは初出な話なようです。
ということで、一応うまくいったので、記事にしておこうと思います。

もともと、Fine-Tuningで見慣れたモデルと学習方法のバリエーションは以下のようなものだと思います。(参考より引用)
今回は、入力部分にもネットワークを付加してさらに学習の自由度を増やす方法を提案したいと思います。

【参考】(以下はpdfに直接リンクです)
Learning without Forgetting ;Zhizhong Li, Derek Hoiem, Member, IEEE arXiv:1606.09282v3 [cs.CV]14Feb 2017

やったこと

・VGG16モデルにtop_modelを変更し、さらに入力部分もinput_modelを追加
・input_modelとしてGaussianNoise_LayerとUpsampling_Layerを追加
・上記モデルをCifar-10とCifar-100に適用した時の精度

コードは以下に置きました

SPP/VGG16_finetuning_callback.py

・VGG16モデルにtop_modelを変更し、さらに入力部分もinput_modelを追加

今回は、入力部分にもネットワークを付加してさらに学習の自由度を増やす方法を提案したいと思います。
コード的には以下のように実装しました。
コード見ればなーんだなものなので、説明する必要はなさそうです。
一応苦労したのは、以下の二点です。
1.input_tensor = x_train.shape[1:]
2.vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_model.output)
あとは、通常のFine-Tuningと同一です。
ここでは、input-modelとして、GaussianNoise-LayerとUpSampling-Layerを試してみました。

input_tensor = x_train.shape[1:] 
input_model = Sequential()
input_model.add(InputLayer(input_shape=input_tensor))
input_model.add(GaussianNoise(0.01))
input_model.add(UpSampling2D((2, 2)))

# Fully-connected層(FC)はいらないのでinclude_top=False)
vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_model.output)

# FC層を構築
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16.output_shape[1:])) 
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.75))
top_model.add(Dense(num_classes, activation='softmax'))

# VGG16とFCを接続
model = Model(input=input_model.input, output=top_model(vgg16.output))

# 最後のconv層の直前までの層をfreeze
#trainingするlayerを指定 VGG16では18,15,10,1など 20で全層固定
for layer in model.layers[1:1]:  
    layer.trainable = False

ちなみに、input_model.summary(), vgg16.summary(), top_model.summary()そしてmodel.summary()の結果は、おまけのとおりに出力されます。

・input_modelとしてGaussianNoise_LayerとUpsampling_Layerを追加

今回は、input_model()としてGaussianNoiseを入れたかったのと、入力時点でUpsamplingする実験をしたかったので入れましたが、さらにinput_model()に、conv2Dなどを入れると可能性は広がると思います。
※ここは追々やって行こうと思います

・上記モデルをCifar-10とCifar-100に適用した時の精度

Upsampling+GaussianNoiseの結果は事前に画像拡大して入力した場合よりちょっとだけ悪い結果となりました。
一方、GaussianNoiseの結果は以下のとおり、Cifar-10で96.08%Cifar-100では80.24%とVGG16ではそれぞれ最高の精度となりました。
これは、参考の①、②によればWideResnetに迫る精度なのでなかなかいい値だと思います。
【参考】
[email protected]
[email protected]
以下がそれぞれのval_acc, val_lossなどの値のプロットです。
Gasussianノイズを過学習を防ぐ目的で0.01入れていますが、まだまだ過学習な状態が残っています。
一方、Gaussianノイズ0.05ではここまでの精度は出ませんでした。





まとめ

・新しいFine-TuningのモデルとしてInput_modelを追加するモデルを提案した
・Cifar-10とCifar-100に適用したところ、96.08%と80.24%というVGG16としては最高精度を得た
 ※ただし入力画像サイズは(160,160,3)に拡大している

・今回提案した学習スタイルでさらにいろいろなケースを試してみたい

おまけ

input_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 32, 32, 3)         0
_________________________________________________________________
gaussian_noise_1 (GaussianNo (None, 32, 32, 3)         0
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 64, 64, 3)         0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
top_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten_1 (Flatten)          (None, 2048)              0
_________________________________________________________________
dense_1 (Dense)              (None, 256)               524544
_________________________________________________________________
dropout_1 (Dropout)          (None, 256)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                2570
=================================================================
Total params: 527,114
Trainable params: 527,114
Non-trainable params: 0
model.summary()
  model = Model(input=input_model.input, output=top_model(vgg16.output))
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 32, 32, 3)         0
_________________________________________________________________
gaussian_noise_1 (GaussianNo (None, 32, 32, 3)         0
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 64, 64, 3)         0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 64, 64, 64)        1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 64, 64, 64)        36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 32, 32, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 32, 32, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 32, 32, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 16, 16, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 16, 16, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 16, 16, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 16, 16, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 8, 8, 256)         0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 8, 8, 512)         1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 8, 8, 512)         2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 8, 8, 512)         2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 4, 4, 512)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 2, 2, 512)         0
_________________________________________________________________
sequential_2 (Sequential)    (None, 10)                527114
=================================================================
Total params: 15,241,802
Trainable params: 15,241,802
Non-trainable params: 0
_________________________________________________________________
vgg16.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 32, 32, 3)         0
_________________________________________________________________
gaussian_noise_1 (GaussianNo (None, 32, 32, 3)         0
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 64, 64, 3)         0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 64, 64, 64)        1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 64, 64, 64)        36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 32, 32, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 32, 32, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 32, 32, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 16, 16, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 16, 16, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 16, 16, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 16, 16, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 8, 8, 256)         0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 8, 8, 512)         1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 8, 8, 512)         2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 8, 8, 512)         2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 4, 4, 512)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 4, 4, 512)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 2, 2, 512)         0
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________