Separating train and validation TensorBoard outputs in Keras


What is this?

Keras is handy, isn't it?
You build a model and check the accuracy curves in TensorBoard...

But train and val end up plotted as separate charts, which is a bit hard to read.
Especially when checking for overfitting, you want to compare the train and val accuracy curves together, so ideally they should be on the same chart!

So here is a tip for getting them displayed together, as shown below.

How to do it

Create the following class and pass callbacks=[TrainValTensorBoard(write_graph=False)] as an argument when calling model.fit.

import os

import keras
import tensorflow as tf


class TrainValTensorBoard(keras.callbacks.TensorBoard):
    def __init__(self, log_dir='../saved/tensorboard', **kwargs):
        # Make the original `TensorBoard` log to the 'train' subdirectory
        training_log_dir = os.path.join(log_dir, 'train')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

        # Log the validation metrics to a separate 'val' subdirectory
        self.val_log_dir = os.path.join(log_dir, 'val')

    def set_model(self, model):
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure with the training metrics
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()

def training():
    # ... (snip) ...
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=[TrainValTensorBoard(write_graph=False)])
    # ... (snip) ...
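
After training, point TensorBoard at the parent log directory. The 'train' and 'val' runs then appear as two overlaid curves on each scalar chart (the path below assumes the default log_dir used in the class above):

    tensorboard --logdir ../saved/tensorboard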

The simplest way to use TensorBoard is to define the callback as below and pass it into the callbacks argument of fit. With that setup, metrics such as acc and val_acc are written to a single run, so TensorBoard draws them as separate charts. If you swap in the class above instead, the validation metrics are written to a separate 'val' run under the same tag names as the training metrics, so TensorBoard overlays the train and val curves on one chart.

    # ... (snip) ...
    tb_cb = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    cbks = [tb_cb]
    # ... (snip) ...
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=cbks)