ランド研究所の「機械学習による航空支配」を実装する(その4): conditional GAN の実装とトレーニング


本記事は、ランド研究所の「機械学習による航空支配」を実装する(その4)です。

Air Dominance Through Machine Learning: A Preliminary Exploration of Artificial Intelligence–Assisted Mission Planning, 2020

(その3)で、ランド研究所のレポートにある1次元問題のためのシミュレータを実装しました。これは、強化学習で言うところの環境に対応します。(但し、極力簡素化したので step by step のシミュレータではありません)。また、トレーニング・データを生成するためと、プランナーの出来を比較するために、ベースラインとなるランダム・プランナーを実装しました。

今回の(その4)では、conditional GAN を使ったミッション・プランナーを実装し、トレーニングします。ただし、レポートに細部の記載は一切ないので、独断で実装します。使用するトレーニング用の正例データ(ランダム・プランナーによりミッションに成功したデータ)、評価用データ、テスト用データは、(その3)で生成したものを使います。

実装コードは、下記 GitHub です。GAN の実装は、Tensorflow 2 を使用しました。これは以前に画像生成用に作ったものを転用したので、変数名が img 等になっていて少し変ですがご容赦ください。

https://github.com/DreamMaker-Ai/AirDominace_1D_GAN_Git

なお、私の体力不足で、GAN や conditional GAN の説明はしていません。ただ、それらを知っていないと内容が消化できないかもしれませんので、消化できなかった際は、原論文を当たるか、ネット上で適当な解説記事をお探しください(ゴメンナサイ)。

Generative Adversarial Networks, 2014
Conditional Generative Adversarial Nets, 2014

過去記事へのリンクは最下段にあります。

1. ミッション・プランニング用の conditional GAN の実装

ミッション・プランニングのための GAN プランナーを実装します。GitHub のコードは、train_conditional_gan.pyになります。

まず、レポートで、GAN と言い張っている意味不明な (その2)の Figure 2.5 ですが、下図のようなトレーニングを企図している図と解釈しました。

図1.1

ここで、conditional GAN の Generator と Discriminator に条件として与える入力は(その2)で定義した4つのミッション条件(mission conditions)です。

Fighter の射程(fighter.firing_range)
SAM の位置(sam.offset)
SAM の射程(sam.firing_range)
ジャミングを受けた時の SAM の射程(sam.jammed_firing_range)

これらは conditional GAN の Paper の目的函数(原論文からの引用):

の条件イベント y に対応します。図には、結果 (real results) という記載もありますが、これは今回は関係ありません。(conditional GAN を進化させた ACGAN 等を使って能力向上できるのか試す私案用です。そのうち実装する予定です)。

また、式 (2) の x 及び G(z|y) が Discriminator への入力になります。x は実際にシミュレーションで成功したミッション・プラン(True)です。一方、G(z|y) は、ミッション条件 y が与えられた下で、 Generator が 正規分布からサンプルした z を使って生成したミッション・プラン(Fake)になります。物理量としては、

fighter.ingress
jammer.ingress

です。

実装では、これらをそのまま使用するのではなく、深層学習で通常行われるように、以下で正規化しました。ここで、sam.max_range = 100km, jammer.jam_range = 30km, sam.max_firing_range = 40km, fighter.max_firing_range = 40km は固定値で、(その2)の Table 2.2 のミッション条件に合わせています。

    - Generator Input
        + 正規分布からのサンプル z ~ N(0, I)
    - Generator Output:2出力
        + ミッション・プランを正規化したもの
            * Jammer.ingress(進出レンジ)/sam.max_offset
            * Fighter.ingress(進出レンジ)/sam.max_offset
            * (Jammet と Fighter のタイミングについてのプランは生成しません。これは、レポートも同じです。私が、step_by_step のシミュレーションを実装しなかった理由はここです)。

    - Discriminator Input:2入力
        + Generator が生成した Fake のミッションプラン、または、シミュレーション結果である True のミッションプラン(正例のみ使用)
    - Discriminato Output
        + Fake/True の確率

    - Condition 入力:4入力 
        + ミッション条件を正規化したもの
            * Fighter の射程 = fighter.firing_range/fighter.max_firing_range
            * SAM 配備位置 = sam.offset / sam.max_offset
            * SAMの 射程 = sam.firing_range / sam.max_firing_range
            * ジャミングを受けた時の SAM の射程 = sam.firing_range * .7 / sam.max_firing_range

成功するミッション・プランを生成したいので、GAN のトレーニングは、レポートと同じく、正例(ランダム・プランナーによるプランでミッションに成功した、つまり、SAM を撃破しFighter と Jammer が生き残ったサンプル)のみを用います。
 
Generator の目的は、正規分布からサンプルした z を使って、成功しそうなミッション・プランを生成することです。一方、Discriminator の目的は、提示されたプランが Generator が正規分布から生成したプランなのか、ランダム・プランナーのプランを使った時に成功した正例プランなのかを判断することです。また、条件入力はミッション条件になります。あとは、目的函数(2)を最適化して行けば、Generator は与えられたミッション条件に対して、次第に正例、つまり成功しそうなミッション・プランを正規分布のサンプルから生成することを学習するはずです。(このため、トレーニングに正例だけを用いています)。

トレーニングは、通常の GAN と同様に、Generator と Discriminator を交互にトレーニングしました。

GitHub のコードでは、以下が GAN トレーニングの主要部分です。これは、以前に画像生成で作ったものを転用しているので、コメント部分の配列サイズは、今回の問題に対応していません(ゴメンナサイ。後で、正確な図表を出します。)

"""
Train the GAN
"""
while True:
    epoch += 1
    print(f'Current epoch = {epoch}')
    d_loss = 0
    g_loss = 0
    d_accuracy = 0
    g_accuracy = 0

    for (X_batch, y_batch) in dataset:
        # データの型を変換
        X_batch = tf.cast(X_batch, tf.float32)
        y_batch = tf.cast(y_batch, tf.float32)

        """
        Train the discriminator
        """
        discriminator.trainable = True
        generator.trainable = False

        # Sample from N(0,1)
        z = np.random.randn(batch_size, coding_size)  # (256,100)
        gen_images = generator([z, y_batch])  # (256,2), float32

        # Make the dataset for the training
        X_fake_vs_real = tf.concat([gen_images, X_batch], axis=0)  # (512,2)
        y_fake_vs_real = \
            tf.concat([tf.zeros(batch_size), tf.ones(batch_size)], axis=0)  # (512,)
        y_fake_vs_real = tf.expand_dims(y_fake_vs_real, axis=1)  # (512,1)
        cond_label = tf.concat([y_batch, y_batch], axis=0)  # (512,5)

        # Train the discriminator
        d = discriminator.fit(x=[X_fake_vs_real, cond_label], y=y_fake_vs_real, epochs=1, verbose=0)
        d_loss += d.history['loss'][0]
        d_accuracy += d.history['accuracy'][0]

        """
        Train the generator
        """
        discriminator.trainable = False
        generator.trainable = True

        # Sample from N(0,1)
        noise = np.random.randn(batch_size, coding_size)  # (256,100)
        lab = y_batch
        y_gen = tf.ones(batch_size)  # (256,)
        y_gen = tf.expand_dims(y_gen, axis=1)  # (256,1)

        # Train the generator
        g = gan.fit(x=[noise, lab], y=y_gen, epochs=1, verbose=0)
        g_loss += g.history['loss'][0]
        g_accuracy += g.history['accuracy'][0]

GitHub のコードでは、以下のクラスで GAN を定義しています。今回の応用は画像処理ではないし、特に畳み込みを使う理由は無いように思えたので、畳み込みは使用しませんでした。アーキテクチャは、深い意味があって決めたわけではなく、ほぼ思いつきだけで決まているので、最適化の余地が多分にあります。レポートには記載は0です。以下も、コメント部分の配列サイズは、今回の問題に対応していません(ゴメンナサイ。後で、正確な図表を出します。)

class GAN(object):
    """
    GANを定義
    """

    def __init__(self, coding_size):
        self.coding_size = coding_size
        self.action_size = 2
        self.condition_size = 4

    def create_generator(self):
        """
        Define the Generator
        """
        g1 = Input(shape=(self.coding_size,))  # (None,100)
        c1 = Input(shape=(self.condition_size,))  # (None,5)
        cond_in = Concatenate(axis=-1)([g1, c1])  # (None,105)

        g2 = Dense(units=HIDDEN_UNITS, activation='relu')(cond_in)  # (None,128)
        g3 = BatchNormalization()(g2)  # (None,128)
        g4 = Dense(units=HIDDEN_UNITS, activation='relu')(g3)  # (None,128)
        # g5 = Dense(units=self.action_size, activation='linear')(g4)  # (None,2)
        g5 = Dense(units=self.action_size, activation=FINAL_ACTIVATION)(g4)  # (None,2)

        generator = Model([g1, c1], g5, name='generator')
        return generator

    def create_discriminator(self):
        """
        Define the Discriminator
        """
        d1 = Input(shape=(self.action_size,))  # (None,2)
        c1 = Input(shape=(self.condition_size,))  # (None,5)
        cond_in = Concatenate(axis=-1)([d1, c1])
        d2 = Dense(units=HIDDEN_UNITS, activation='relu')(cond_in)  # (None,128)
        d3 = Dropout(rate=0.5)(d2)  # (None,128)
        d4 = Dense(units=HIDDEN_UNITS, activation='relu')(d3)  # (None,128)
        d5 = Dense(units=1, activation='sigmoid')(d4)  # (None,1)

        discriminator = Model([d1, c1], d5, name='discriminator')
        return discriminator

    def create_gan(self):
        # Create the Generator
        generator = self.create_generator()

        # Create the Discriminator
        discriminator = self.create_discriminator()
        optimizer = Adam(learning_rate=DISCRIMINATOR_LR)
        discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        # Create the GAN
        noise_input = Input(shape=(self.coding_size,))  # (None,100)
        cond_input = Input(shape=(self.condition_size,))  # (None,5)
        f_image = generator([noise_input, cond_input])  # (None,2)
        gan_output = discriminator([f_image, cond_input])  # (None,1)

        gan = Model([noise_input, cond_input], gan_output, name='gan')
        optimizer = Adam(learning_rate=GAN_LR)
        gan.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        return gan, generator, discriminator

Generator、Discriminator のアーキテクチャ、及び GAN のトレーニング・アーキテクチャは、サイズを含めて書くと下図表となります。(これは、今回用いた正確なサイズです)。図中の ? と表中の None は、Tensorflow の計算グラフ作成時には未定で、実行時に与えるバッチ次元を表しています。

1.1 Generator

1.2 Discriminator

1.3 GAN トレーニング・アーキテクチャ

1.4 ハイパー・パラメータの設定

設定したハイパー・パラメータは下表にまとめました。プランを生成する正規分布 z の次元は 20 としました。これも、レポートには何も書かれていないので、こんなもんかなぁ、だけで決めています。Generator の最終層の活性化函数としては [0, 1] の実数値を出力してくれればよいだけなので 'sigmoid' 函数にしたいところですが、sigmoid 函数だと 1 や 0 に近づくほど勾配が緩やかになってしまいぎりぎりのミッション条件をうまく処理してくれない気がしたので 'relu' 函数としたものも実験し、比較することにしました。

2. トレーニング中の評価

2.1 評価指標

GAN (Generator) の学習状況はロスを見ていてもサッパリ判らない(私だけ?)ので、評価データ、テスト・データのミッション条件 10,000 ポイントを GAN プランナーに与え、プランナーが生成したプランに従って行動した場合のシミュレーションを行い、ミッションの成功率で評価します。テストと評価は同じことをやっています(片方だけで良かったですね)。

conditional GAN では、ミッション条件に加えて、正規分布からのサンプル z も入力になります。したがって、GANのプランニング性能を正しく評価するには、各ミッション条件に対して、多数の z をサンプリングして統計量で評価する必要があります。

と言っても、ミッション条件がテスト、評価それぞれで、10,000ポイントづつあるので、Google 並みの計算機パワーを持っていればこれが可能かもしれませんが、私のビンテージマシンではできません。

そこで、以下のように評価することにしました。

私のビンテージ・マシンの性能から、ミッション条件1つに対して多数の z をサンプリングすることはできませんが、ミッション条件1つに対して1つの z をサンプリングするぐらいのことはできます。これをミッション条件すべて(10,000ポイント)に対し実施して平均成功率を出せば、その学習ステップにおける GAN の何らかの性能を表しているはずです。さらに、GANはそう急速には学習しないと考えれば、ステップ間で移動平均を求めれば、それが本来の統計量に近い量になる可能性が有ります。幸い Tensorboard や 私がよく使っている wandb には、スムージング機能(指数移動平均;EMA: Exponential Moving Average)があるので、わざわざ移動平均を計算する必要がありません。スムージング係数は 0.8 で固定しました。
 
また、成功率としては、全ミッション条件に対する成功率(これは、(その3)で示したように、上限が 58.44% です)と、ミッション成功の可能性があるミッション条件に対する成功率(これは、上限が 100% になります)が考えられるので、これら両方の成功率を計算することにしました。

2.2 トレーニング結果

2.2.1 トレーニング履歴

下図が、トレーニング・データ中の全ミッション条件に対する成功率(理論上限が 58.44%)です。横軸が学習エポック数、縦軸が成功率です。薄く見えているのが、スムージング前の値です。

図2.1

今回の実験では、Generator の最終層に "sigmoid" 函数を使ったほう(青ライン)が "relu" 函数を使う(橙ライン)よりも、学習が安定するようです。本当は、何回か実行して統計を取るべきなのですが、ビンテージマシンでは1回のトレーニングに丸1日かかってしまったので断念しました。(大半の時間は、GAN のトレーニング時間ではなく、シミュレーションの実行時間です。シミュレータも、それっぽくループで回さず、素直にバッチ処理にすべきでした)。

トレーニング開始時点では、成功率 ≃ 0% ですが、2k のトレーニング・エポック後の成功率は 40% 強といったところです。

ベースラインであるランダム・プランナーでは、ミッション条件全体に対しては、成功率が 10% 程度でしたので、約 30% 程度は成功率が改善されていますが、最大値である 58% よりは未だかなり低い値でした。また、グラフから判るように、わりと早いエポックで学習が進みます。

下図は、ミッション成功の可能性があるミッション条件に対する成功率(最大で 100%)の履歴です。成功率は 70% 程度でしたので、30% ぐらい取りこぼしがあることになります。これらについては、(その5)で詳しく見ていきます。
図2.2

2.2.2 GAN のロスと精度 (accuracy) 履歴

GAN(Generator)と Discriminator のトレーニング・ロスと精度(accuracy)は以下のような履歴となっています。gan_lossとあるのはGAN形態でGeneratorをトレーニングしている時のロスです。ロスも精度も何度か、ステップ上に変化していますが、不安定にはなっていないようです。ただ、これからそれ以上の何かが判るのかと言うと、私にはわかりません。

(その5)へ続く

(その5)では、トレーニング結果を分析し、GAN が生成しようとしているプランを読み解きます。また、GAN の性能に影響を与えているファクタを特定し、性能改善を試みます。

過去記事へのリンク