Keras Functional APIでXORを実装してみる


はじめに

本格的にKerasを始めようと思い、まずはXORの予測モデルを構築するコードを書いてみた。将来的にGraph Convolutional Networksなど複雑なモデルも作りたいと思っているため、Functional APIから始めてみる。

ソース

こんな感じ。
BatchNormalizationをつけないと、局所解に陥りやすかったのでつけている。最終層はlinear層にしている事例が多かったが、0, 1のクラス分類問題であり、納得がいかなかったので、sigmoid関数とした。中間層2層のユニット数は、それぞれ8としてみた。

sample.py
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Input, Dropout, BatchNormalization

import numpy as np


def main():
    x_input = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    y_input = np.array([[0], [1], [1], [0]])

    input_tensor = Input(shape=(x_input.shape[1]))
    # x = Dense(units=32, activation="tanh", kernel_initializer='random_normal')(input_tensor)
    x = Dense(units=8, activation="relu", kernel_initializer='random_normal', use_bias=True)(input_tensor)
    x = BatchNormalization()(x)
    x = Dropout(0.1)(x)
    x = Dense(units=8, activation="relu", kernel_initializer='random_normal', use_bias=True)(x)
kernel_initializer='random_normal', use_bias=False)(x)
    output_layer = Dense(units=1, activation='sigmoid', kernel_initializer='random_normal', use_bias=False)(x)
    model = Model(input_tensor, output_layer)

    model.compile(loss='mse',  optimizer='sgd', metrics=['accuracy'])
    model.summary()

    # 学習
    model.fit(x_input, y_input, nb_epoch=2000, batch_size=2, verbose=2)

    # 予測
    print(model.predict(np.array([[0, 0]])))
    print(model.predict(np.array([[1, 0]])))
    print(model.predict(np.array([[0, 1]])))
    print(model.predict(np.array([[1, 1]])))


if __name__ == "__main__":
    main()

モデル概要

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 2)]               0
_________________________________________________________________
dense (Dense)                (None, 8)                 24
_________________________________________________________________
batch_normalization (BatchNo (None, 8)                 32
_________________________________________________________________
dropout (Dropout)            (None, 8)                 0
_________________________________________________________________
dense_1 (Dense)              (None, 8)                 72
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 8
=================================================================
Total params: 136
Trainable params: 120
Non-trainable params: 16

予測結果

[[0.09092495]]
[[0.9356866]]
[[0.90092343]]
[[0.08152929]]

おわりに

今回手始めにやってみたが、パラメータチューニング、可視化など色々試してみたい。