[Survey]Adversarial Autoencoders


Adversarial Autoencoders

この論文の目的は、autoencoderlatent code vectorが任意の分布になるような学習方法を提案することです。
adversarial autoencodersに関する論文としては以下3つが有名で各サイトで説明されているので、今回は細かい説明は省略したいと思います。

  1. Generative Adversarial Nets
  2. Adversarial Autoencoders
  3. Unsupervised Representation learning With Deep Convolutional Generative Adversarial Networks

1番の論文に関しては下記サイトがとても詳しく、Tensorflowによる実装も掲載されています。
http://evjang.com/articles/genadv1

Adversarial autoencoders

Adversarial autoencodersのブロック図と学習の流れを示したのが下記の図になります。
上段がautoencoderになります。autoencoderは入力されたデータの次元を圧縮して(encode)、元の次元戻す(decode)処理をするもので、なるべく元の入力と同じになるように重みを学習します。autoencoderlatent code vector $z$の分布$q(z)$と任意の分布$p(z)$が同じになるようにautoencoderを学習させます。$q(z)$と$p(z)$が似ているのかどうか判断するのがDiscriminatorの役割です。もし$q(z)$と$p(z)$が同じになるようにautoencoderを学習できれば、任意の分布$p(z)$から生成した$z$をDecoderに入力してあげると、$x$を入力しなくても$x$に似たデータが$z$から生成できます。

最適化の式は下記のようになります。
ここで$D$はDiscriminator modelで$G$はGenerator Model、$p_{data}$任意の分布、$p(z)$はGeneratorにより生成された$z$の分布です。
学習は下記のようにおこないます。
1. true sample(任意の分布のデータ)とfake sample(Generatorにより生成されたデータ)をうまく分離できるようにDiscriminatorを学習します。下記の式の第一項はtrue sampleのときに1と言うようになれば値が大きくなります。第二項はfake sample($G(z)$)のときに0と言うようになれば値が大きくなります。足したものを$D$に関して最大化すればtrue samplefake sampleがうまく分離できるようになります。
2. GeneratorDiscriminatorを騙せるように学習します。Discriminatorをうまく騙せると$D(G(z))$が1になります。なので下記の式を$G$に関して最小化すればDiscriminatorをうまく騙せるようになります。

実際にMNISTのデータを学習させる時は下記のような構造にします。基本的には変わりありませんが、Discriminatorにデータが0なのか1なのか...9なのかを示したone hot vectorを渡します。

Result

もうすでにAdversarial autoencodersのコードはいろいろ公開されていますが、自分でもtensorflowで実装して学習させてみました。(もはや何番煎じかわかりませんが...)
実装の際には下記のコードやサイトを参考にしました。
http://musyoku.github.io/2016/02/22/adversarial-autoencoder/
https://github.com/takerum/adversarial_autoencoder
https://www.reddit.com/r/MachineLearning/comments/3ybj4d/151105644_adversarial_autoencoders/?

implementation

ネットワーク構造及び学習パラメータは下記の通りです。

・Encoder, Decoder, Discriminatorともに3 layer network
・すべてのネットワークのhidden layerのunit数は1000
・Encoderはfirst, second layerでReLUを使用、last layerは何もなし
・Decoderはfirst, second layerでReLUを使用、last layerはsigmoidを使用。
・Discriminatorは、first, second、last layerでReLU使用
・Latent code vectorは2次元
・100 epoch
・batch size 100
・Batch Normalization,weight decayは無し
・OptimizerはAdam
・Learning rateは、autoencoderとgeneratorは0.001でdiscriminatorは0.0002
・generatorだけ学習回数を2倍にしました。(autoencoder, discriminator, generatorの順で学習させるとうまく学習が収束しなかったので、autoencoder, discriminator, generator, generatorという感じにしました。ほんとうにこれでいいのか不明ですが・・・。)
・autoencoderは最初にpre-trainingしました。

実際に学習させた結果下記のようになりました。

10 2D Gaussian

MNISTの0〜9のデータを10個の2D Gaussian分布に押し込みます。

epoch毎の$q(z)$の分布をgif animationにしてみました。qittaは画像サイズが1M以上だと圧縮されてgif animationが動かなくなるので、160x160にResizeしました。小さすぎてなんだかよくわからなくなってしまいましたが・・・。

decoderに$p(z)$を入力した時のepoch毎の出力結果をgif animationにしたものです。




swiss roll

MNISTの0〜9のデータを2D swiss roll分布に押し込みます。

epoch毎の$q(z)$の分布をgif animationにしてみました。

decoderに$p(z)$を入力した時のepoch毎の出力結果をgif animationにしたものです。



最後に

試行錯誤のすえ、ようやく学習が収束するようになりました。自分は色々な人の実装を参考にしながらやったので何とかなりましたが、これを最初に実装した人はすごいなと思います。実装してみた感想としては、うまく収束させるのが難しいです。ただ3つを順番に学習させればいいというわけではなさそうで、3つのバランスが重要な気がします。できたばかりでコードがぐちゃぐちゃのため整理したらGithubにあげようかなと思っています。

code