tf.Keras @ TensorFlow 2.xにおけるTFRecord活用の備忘録 2/2【学習・評価実行編】


前回の記事「tf.Keras @ TensorFlow 2.xにおけるTFRecord活用の備忘録 1/2【TFRecord作成編】」で作成したTFRecordを用いて、学習と評価を実行します。モデルにはEfficientNetB7を用いました。

学習・評価実行用ソースコード

ソースコード例(Jupyter Notebookでの実行を想定)
# Training / evaluation notebook

# Root directory of the dataset. The other paths are built by appending
# directly to this string, so a non-empty value must end with a separator.
dataset_root = ''
train_path = f'{dataset_root}images/train/'
label_path = f'{dataset_root}labels/'

import glob
import os
import pathlib
import json
import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import cv2
import IPython.display as display
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.keras.applications.efficientnet import EfficientNetB7
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Flatten, Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

# Model / training hyper-parameters
img_size = 128      # images are resized to img_size x img_size
num_channels = 3    # channel count used for JPEG decoding and the model input
num_class = 2       # a value of 1 or 2 selects the single-unit sigmoid head
batch_size = 64
epoch = 20
loss = 'binary_crossentropy'
metrics = ['acc', tf.keras.metrics.AUC()]

# Print every directory found under the dataset root to check the layout.
for entry in glob.glob(dataset_root + '**', recursive=True):
    if not pathlib.Path(entry).is_dir():
        continue
    print(entry)

train_tfrecord_path = dataset_root + 'tfrecord/train/'
test_tfrecord_path = dataset_root + 'tfrecord/test/'
print(train_tfrecord_path)
print(test_tfrecord_path)

# Collect the training TFRecord files in name order.
files = sorted(os.listdir(train_tfrecord_path))
tfrecords = [
    train_tfrecord_path + name
    for name in files
    if os.path.splitext(name)[1] == '.tfrecord'
]
print(tfrecords)

# Hold one shard out for validation (the last one by default) and train on
# all the others.
val_idx = len(tfrecords) - 1
val_tfrecord = tfrecords[val_idx]
train_tfrecord = [shard for i, shard in enumerate(tfrecords) if i != val_idx]

# Feature schema of one training record: the encoded image bytes, the image
# name and the integer target label, each stored as a scalar feature.
image_feature_description = dict(
    image=tf.io.FixedLenFeature([], tf.string),
    image_name=tf.io.FixedLenFeature([], tf.string),
    target=tf.io.FixedLenFeature([], tf.int64),
)

def _parse_image_function(example_proto):
    """Parse one serialized tf.Example with the training feature schema.

    Returns the feature dict produced by ``tf.io.parse_single_example`` for
    ``image_feature_description``.
    """
    parsed_features = tf.io.parse_single_example(
        example_proto, image_feature_description
    )
    return parsed_features


def _parse_image_function_for_test(example_proto):
    """Parse one serialized tf.Example with the test feature schema.

    ``test_image_feature_description`` is looked up when this function runs,
    so it only has to exist by the time a test dataset is mapped.
    """
    parsed_features = tf.io.parse_single_example(
        example_proto, test_image_feature_description
    )
    return parsed_features

def _resize_image_function(example_proto):
    """Replace the record's JPEG by a copy resized to img_size x img_size.

    The "image" entry is decoded with ``num_channels`` channels, resized,
    cast to uint8 and encoded as JPEG again. The other entries are left
    untouched and the same (updated) dict is returned.
    """
    target_shape = (img_size, img_size)
    decoded = tf.image.decode_jpeg(example_proto["image"], channels=num_channels)
    resized = tf.cast(tf.image.resize(decoded, target_shape), tf.uint8)
    example_proto["image"] = tf.image.encode_jpeg(resized)
    return example_proto

def _extract_xy_function(example_proto):
    """Turn a parsed training record into an ``(image, label)`` pair.

    The image is the decoded JPEG cast to float32 (pixel values are not
    rescaled here); the label is the "target" feature cast to int32.
    """
    pixels = tf.image.decode_jpeg(example_proto["image"], channels=num_channels)
    return tf.cast(pixels, tf.float32), tf.cast(example_proto["target"], tf.int32)

def _extract_xy_function_for_test(example_proto):
    """Turn a parsed test record into an ``(image, image_name)`` pair.

    The image is the decoded JPEG cast to float32; the name is the
    "image_name" feature as a string tensor.
    """
    pixels = tf.image.decode_jpeg(example_proto["image"], channels=num_channels)
    image_name = tf.cast(example_proto["image_name"], tf.string)
    return tf.cast(pixels, tf.float32), image_name

def count_dataset_size(raw_dataset):
    """Return how many elements iterating ``raw_dataset`` yields.

    The count is obtained by one complete pass over the iterable.
    """
    return sum(1 for _ in raw_dataset)

def load_tfrecords(path, shuffle_op=False):
    """Build the batched, endlessly repeating (image, label) dataset.

    Args:
        path: TFRecord file name(s), passed as-is to
            ``tf.data.TFRecordDataset``.
        shuffle_op: when truthy, shuffle the examples with a buffer as large
            as the whole dataset.

    Returns:
        A ``(dataset, dataset_size)`` tuple. ``dataset`` yields batches of
        ``batch_size`` examples and repeats forever; ``dataset_size`` is the
        number of records, counted by one full pass over the files.
    """
    raw_dataset = tf.data.TFRecordDataset(path)
    # The dataset repeats forever, so callers need the record count to derive
    # the number of steps per epoch.
    dataset_size = count_dataset_size(raw_dataset)
    dataset = (
        raw_dataset
        .map(_parse_image_function)
        .map(_resize_image_function)
        .map(_extract_xy_function)
    )
    if shuffle_op:
        # NOTE(review): the buffer holds every decoded image at once --
        # confirm this fits in memory for large datasets.
        dataset = dataset.shuffle(buffer_size=dataset_size)
    dataset = dataset.repeat().batch(batch_size)
    # Prepare the next batches while the current one is being consumed.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset, dataset_size

def load_tfrecords_for_test(path):
    """Build the batched (image, image_name) dataset used for prediction.

    Unlike the training loader, this dataset is neither shuffled nor
    repeated. Returns ``(dataset, dataset_size)`` where ``dataset_size`` is
    the number of records counted by one full pass over the files.
    """
    raw_records = tf.data.TFRecordDataset(path)
    dataset_size = count_dataset_size(raw_records)
    batched = (
        raw_records
        .map(_parse_image_function_for_test)
        .map(_resize_image_function)
        .map(_extract_xy_function_for_test)
        .batch(batch_size)
    )
    return batched, dataset_size

def save_model(model_name, model):
    """Save ``model`` as three files whose names start with ``model_name``.

    Written in this order, each followed by a confirmation message:
      * ``<model_name>.h5``         -- via ``model.save``
      * ``<model_name>.json``       -- ``model.to_json()`` dumped with ``json``
      * ``<model_name>_weights.h5`` -- via ``model.save_weights``
    """
    # 1) The complete model.
    full_model_path = model_name + '.h5'
    model.save(full_model_path)
    print('Saved: ' + full_model_path)

    # 2) The architecture. to_json() already returns a string, so the file
    #    holds that string JSON-encoded once more; kept as is for readers
    #    that json.load() the file to get the string back.
    architecture_path = model_name + '.json'
    architecture = model.to_json()
    with open(architecture_path, 'w') as out:
        json.dump(architecture, out)
    print('Saved: ' + architecture_path)

    # 3) The weights only.
    weights_path = model_name + '_weights.h5'
    model.save_weights(weights_path)
    print('Saved: ' + weights_path)

def check_dataset(path, op):
    """Display the first image stored in the TFRecord file(s) at ``path``.

    Args:
        path: TFRecord file name(s) for ``tf.data.TFRecordDataset``.
        op: ``"train"`` parses records with the training schema; any other
            value parses them with the test schema.
    """
    parser = _parse_image_function if op == "train" else _parse_image_function_for_test
    first_record = tf.data.TFRecordDataset(path).map(parser).take(1)

    for features in first_record:
        # The stored bytes are handed to the notebook display without decoding.
        display.display(display.Image(data=features['image'].numpy()))

## Sanity-check the training data by displaying its first image
check_dataset(train_tfrecord, "train")

# Load the training and validation datasets (both unshuffled) and print
# their record counts.
ds, ds_size = load_tfrecords(train_tfrecord, shuffle_op=False)
val_ds, val_ds_size = load_tfrecords(val_tfrecord, shuffle_op=False)
print(ds_size)
print(val_ds_size)

def calc_steps(ds_size, val_ds_size, batch_size):
    """Return ``(train_steps, val_steps)`` for one pass over each dataset.

    Each value is the dataset size divided by ``batch_size``, rounded up so
    that a final partial batch is counted as a step.
    """
    train_steps, val_steps = (
        math.ceil(size / batch_size) for size in (ds_size, val_ds_size)
    )
    return train_steps, val_steps

# Steps needed to cover the training / validation data once per epoch.
steps_per_ep, val_steps = calc_steps(ds_size, val_ds_size, batch_size)
print(steps_per_ep, val_steps, sep='\n')

# Backbone: ImageNet-pretrained EfficientNetB7 without its classification
# top; pooling='avg' is requested so the output feeds a Dense head directly.
input_tensor = Input(shape=(img_size, img_size, num_channels))
# base_model = InceptionResNetV2(weights='imagenet', include_top=False, pooling='avg', input_tensor=input_tensor)
base_model = EfficientNetB7(weights='imagenet', include_top=False, pooling='avg', input_tensor=input_tensor)

# summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line to the output.
base_model.summary()

x = base_model.output

# One or two classes: a single sigmoid unit; more classes: a softmax layer.
if 0 < num_class <= 2:
    model_output = Dense(1, activation='sigmoid', name='output')(x)
else:
    model_output = Dense(num_class, activation='softmax', name='output')(x)

model = Model(inputs=base_model.input, outputs=model_output)
model.summary()
print(len(model.layers))

# Freeze everything except the final layer
for layer in model.layers[:-1]:
    layer.trainable = False

# To train all parameters instead
# for layer in model.layers:
#     layer.trainable = True

model.compile(loss=loss, optimizer='adam', metrics=metrics)
# Printed again after freezing and compiling.
model.summary()

# Both datasets repeat forever, so the step counts bound each epoch.
hist = model.fit(ds, validation_data=val_ds, validation_steps=val_steps, epochs=epoch, steps_per_epoch = steps_per_ep)

history = hist.history
idx = list(history.keys())

# Plot the training curves. History entries are looked up by position:
# positions 0-2 are read as loss / acc / auc and positions 3-5 as their
# validation counterparts.
for curve, train_pos, val_pos in (('acc', 1, 4), ('loss', 0, 3), ('auc', 2, 5)):
    pd.DataFrame({
        curve: history[idx[train_pos]],
        'val_' + curve: history[idx[val_pos]],
    }).plot()
plt.show()

# Save the trained model
save_model('efficientnetb7_shuffle_128', model)

# Feature schema of one test record: same as the training schema but without
# the target label.
test_image_feature_description = dict(
    image=tf.io.FixedLenFeature([], tf.string),
    image_name=tf.io.FixedLenFeature([], tf.string),
)

# Collect the test TFRecord files in name order.
files = sorted(os.listdir(test_tfrecord_path))
test_tfrecords = [
    test_tfrecord_path + name
    for name in files
    if os.path.splitext(name)[1] == '.tfrecord'
]
# Printed explicitly: a bare expression is only shown when it is the last
# statement of a notebook cell.
print(test_tfrecords)

## Sanity-check the test data by displaying its first image.
## This used to be given the training shards, so the test files were never
## actually inspected.
check_dataset(test_tfrecords, "test")

test_ds, test_ds_size = load_tfrecords_for_test(test_tfrecords)
print(test_ds_size)

# Split the (image, name) batches into two views of the same dataset. The
# test loader does not shuffle, so both views are built from the same
# ordered pipeline.
test_image_ds = test_ds.map(lambda image, label: image)
test_label_ds = test_ds.map(lambda image, label: label)

# Flatten the predictions into a 1-D array of scores.
result = model.predict(test_image_ds).flatten()

print(result.shape)

# Image names come back as bytes. Decode them rather than slicing the repr
# of the bytes object, which breaks on names containing an apostrophe or
# non-ASCII characters.
names = [name.numpy().decode('utf-8') for name in test_label_ds.unbatch().take(test_ds_size)]

# Write the results, sorted by image name
csv_file_name = 'efficientnetb7_shuffle_128.csv'
output_ds = pd.DataFrame({'image_name': names, 'target': result}).sort_values(by='image_name', ascending=True)
output_ds.to_csv(csv_file_name, index=False)

今後の改善点

  1. スクリプトファイル化
  2. 冗長な部分を共通化するなどしてリファクタリングする

Reference

今後の研究用(TF2.xのドキュメント)


  1. このサイトのオーナーに実装のアドバイスを貰いました。