論文読み:f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization


解釈がおかしい部分があったらコメントください。

論文:f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization

1. Abstract

ランダムなノイズから確率分布を基に生成 (Sample) を行うGenerative neural samplerは表現力はあるが, サンプリングした画像の尤度の計算などは行っていない. そこで、その問題を改善するために, discriminative neural network を用いたGANの敵対的学習が提案された.
GANでは, 目的関数の最適化に JS Divergence を用いているが, 本論文ではそれを一般化した f-divergence を利用した f-GANを提案し, 画像の生成において, どのdivergenceを利用したら精度が高くなるのか検証をした.

2. Introduction

Generative neural sampler(GNS)のような生成モデルでは以下の操作が可能かが重要である.

  • Sampling:生成を行う. 生成結果によって生成モデルの確率分布やその生成過程を推定することができる
  • Estimation:未知の真の確率分布から生成されたサンプルを基にその真の確率分布を推定する
  • Point-wise likelihood evaluation:与えられたサンプルに対して尤度を計算する

GANでは以下の Jensen-Shannon divergence を最小化することで, GNSの最適化を行っていた.

D_{JS}(P||Q) = \frac{1}{2}D_{KL}(P||\frac{1}{2}(P+Q))+\frac{1}{2}D_{KL}(Q||\frac{1}{2}(P+Q))

GANの学習は収束しづらいと言われているので, 本論文では, Nguyenによって提案された f-divergence を利用してGNSの最適化を行うことで, 学習を安定させた.

3. Method

確率変数$x$の集合$\chi$において, f-divergence は以下のように定義される.

D_{f}(p||q)=\int_{\chi}q(x)f\left(\frac{p(x)}{q(x)}\right)dx\tag{1}

ここで, divergence の同一性より, $f$は $f(1)=0$を満たす関数で, 例えば, $f(x)=xlog(x)$とすれば KL divergence になる. f-GANでは, モデルの学習に, この f-divergence を用いるが, 本論文ではこの手法をvariational divergence minimization(VDM)と呼ぶ. VDM における目的関数を示すために, f-divergence が提案された論文における divergenceの推定方法を見ていく.

下半連続な凸関数(任意の $x_{0}$について $\underset{x\rightarrow x_{0}}{lim}:inf:f(x)\ge f(x_{0}$))は, 以下のような凸共役関数 $f^{*}$ を持つ.

f^{\ast}(t)=\underset{u\in dom_{f}}{sup}\{ut-f(u)\}

また, この共役関数は真凸関数であり, 下半連続なので, 閉真凸関数である.
従って, 共役関数$f^{*}$の共役関数 $f^{**}$($f$の双共役関数)に対して, $f^{**}=f$ が成り立つ. よって,

f(u)=f^{\ast*}(u)=\underset{t\in dom_{f^{*}}}{sup}\{tu-f^{*}(t)\}

が成り立つ. これを f-divergence の定義の式に代入すると,


D_{f}(p||q)=\int_{\chi}q(x)\underset{t\in dom_{f^{*}}}{sup}\{t\frac{p(x)}{q(x)}-f^{*}(t)\}dx

$f^{*}(x)$は凸関数なので, イェンセンの不等式が成り立つ. 従って,

\geq\sup_{T\in\mathrm{\tau}}(\int_{\chi}p(x)T(x)dx-\int_{\chi}q(x)f^{*}(T(x))dx)
=\sup_{T\in\mathrm{\tau}}(\underset{Sample\:from\:P}{\underbrace{\mathbb{E}_{x\sim P}[T(x)]}}-\underset{Sample\:from\:Q}{\underbrace{\mathbb{E}_{x\sim Q}[f^{*}(T(x))]}})\tag{2}

ここで, $\tau$はそれぞれの divergenceである. つまり, 上式は f-divergence がそれぞれの divergence の下限であることを示していて, どのdivergenceも上の式で推定することができることを表している. また, 上式において以下のような関係が成り立っている.

T^{*}(x)=f'(\frac{p(x)}{q(x)})

これが, $f(x)$を選ぶときの指針となる. 例えば下の図のreverse Kullback-Leibler divergence は, $f(x)=-log(x)$であり, $T^{*}(x)=-\frac{q(x)}{p(x)}$である.

3.1 Variational Divergence Minimization (VDM)

真の関連度分布$P$ (データセット) が与えられた時にGenerator の確率分布を推定するため(1)の式を利用する.ノイズをインプットとして受け取り, そのノイズからサンプルを生成するGeneratorを$Q$とし, そのパラメータを$θ$とする. 一方サンプルをインプットとして受け取り, スカラー値を返す関数を$T$とし(Discriminator), そのパラメータを$ω$とする. その時, f-GANの目的関数は以下のようになる.

F(\theta,\omega)=\mathbb{E}_{x\sim P}[T_{\omega}(x)]-\mathbb{E}_{x\sim Q_{\theta}}[f^{*}(T_{\omega}(x))]\tag{3}

この目的関数を基にMinMaxゲームでθを最小化し, ωを最大化していく.

3.2 Representation for the Variational Function

先ほどの目的関数に対して, 様々なdivergenceを適用させるために, 凸共役関数 $f^{*}$ の有効領域 $dom_{f^{*}}$を考えると, (3)の式は $T_{\omega}(x)=g_{f}(V_{\omega}(x))$として以下のように書き換えられる.

F(θ,ω)=\mathbb{E}_{x\sim P}[g_{f}(V_{ω}(x))]+\mathbb{E}_{x\sim Q_{θ}}[-f^{*}(g_{f}(V_{ω}(x)))]\tag{4}

$g_{f}$は, 特定のdivergenceを使った時の活性化関数の出力を表す, それぞれの凸共役関数 $f^{*}$ (divergence)に対して適切な出力の活性化関数とその有効領域をまとめたのが以下の図である.

ここで, $f'(1)$は真のサンプルとGeneratorによるサンプルを区別するために$T(x)$に適用される閾値のようなものを表す.
それぞれのdivergenceにおいて(4)の目的関数の左項と右項の学習の様子は以下の図のようになる.

3.3 Example: Univariate Mixture of Gaussians

それぞれのdivergenceとそれらを基にした目的関数 $F(\theta,\omega)$ に対して, 混合ガウスモデルを用いてその確率密度分布と$\theta=(\mu,\sigma)$($\mu$は平均, $\sigma$は分散)を求めたところ以下ような結果が得られた.

どのdivergenceにおいても, $D_{f}(P||Q_{\theta^{*}})\ge F(\hat{\omega},\hat{\theta})$なので, f-divergenceが, それぞれのdivergenceの下限(lower bound) であることが確認された. これは, f-divergenceにおいて$f$を変えるだけで様々なdivergenceが使えるということを示している. 次に, それぞれのdivergenceで学習させた後, $Q$を固定して再度$T$を学習させてテストしたところ, 以下の図のように学習させた時のdivergenceと一致している時に確率分布のズレが少なくなった. これは, f-divergence が divergence 毎の学習ができていることを示している.

つまり, Generatorは学習に用いたdivergenceに強く影響を受けることを示している.

4. Algorithms for Variational Divergence Minimization (VDM)

GANの学習では, Double-loop method を用いていたが, f-ganでは, Single-step gradient method によって学習を行う.

Double-loop method

  • inner loop : divergence の 下限(lower bound)をきつく(tightに)する
  • outer loop : inner loopを基に Generator の損失を最小化する

この方法ではouter loopの際に2度の誤差逆伝播が必要になる

Single-step gradient method

  • inner loopがなく, $\omega$と $\theta$ の勾配は一回の誤差逆伝播で計算される.

Single-step gradient method による最適化をまとめると, 以下のようなアルゴリズムで行われる.

GANは学習が不安定と言われるが, このような学習方法で, $F(\theta,\omega)$ が収束するのか考える. 収束したときの $(\theta,\omega)$ を $(\theta^{*},\omega^{*})$とすると, その周辺において$F$は$\theta$に対して強く凸性を持ち, $\omega$に対して強く凹性を持つので, 以下のような仮定ができる.

\nabla_{θ\text{}}F(θ^{*},ω^{*})=0,\;\nabla_{ω}F(θ^{*},ω^{*})=0,\;\nabla_{θ}^{2}F(θ,ω)\geqq\delta I,\;\nabla_{ω}^{2}F(θ,ω)\leqq-\delta I\tag{5}

簡略化のために, $\pi^{t}=(\theta^{t},\omega^{t})$ として, Algorithm1の収束を証明していく.
$\pi^{*}=(\theta^{*},\omega^{*})$において(5)が成り立つと仮定したとき, $J(\pi)=\frac{1}{2}||\nabla F(\pi)||_2$と定義すると, $\pi^{*}$よりも上の(above)$\pi$において, $F$は十分滑らかな関数(微分可能)なので, $||\nabla J(\pi')- \nabla J(\pi)||_{2}{\le}L||\pi'-\pi||_{2}$となり, この時の定数 $L$ は$L>0$となる. Algorithm1のステップ数 $\eta=\frac{\delta}{L}$を使うと, 以下の数式が得られる.

J(\pi^{t})\leq(1-\frac{\delta^{2}}{L})^{t}J(\pi^{0})

従って, $J(\pi)=\frac{1}{2}||\nabla F(\pi)||_{2}^{2}$ より, 勾配 $\nabla F(\pi)$ は小さくなり, 収束に向かうことがわかる.($\pi$を$\pi^{*}$よりも上と仮定したため)

GANでは, Generatorの最適化の際に

\mathbb{E}_{x\sim Q_{\theta}}[log(1-D_{\omega}(x))]

を最小化する代わりに,

\mathbb{E}_{x\sim Q_{\theta}}[logD_{\omega}(x)]

を最大化していたので, f-GANにおいても, Algorithm1の4行目を以下のように置き換える.

\theta^{t+1}=\theta^{t}+\eta\nabla_{\theta}\mathbb{E}_{x\sim Q_{\theta^{t}}}[g_{f}(V_{\omega^{t}}(x))]\tag{6}

5. Experiments

MNISTとLSUNのデータセットを用いて, VDMに基づいたGeneratorを学習させた.

5.1 MNIST Digits

実験の設定

  • 60,000枚の手書き数字の画像があるMNISTデータセットを用いて, それぞれのdivergenceを基にGeneratorの学習を行った
  • Generatorは2層のレイヤーがあり, それぞれのレイヤーの後でBatch normalizationを行い, それぞれの層の出力には活性化関数のReLUを適用した.また最後的な出力にはsigmoid関数を用いた
  • $V_{\omega}(x)$(Discriminatorのようなもの)は3層のレイヤーがあり, それぞれの層の間には活性化関数として, ELUが用いられている. 最後の層ではそれぞれのdivergenceごとの活性化関数を用いている
  • VAEと比較を行った

結果
生成された画像(16k)に対して, カーネル密度推定モデルによって推定された確率密度関数の対数尤度(確率密度の当てはまり)と標準誤差を見たところ, 以下の結果が得られた.

対数尤度が大きくなるほど当てはまりが良いということを表す. この結果より, GANよりもKL-divergenceを用いたほうが, 対数尤度が高く, 標準誤差 (統計量のばらつき) も小さいことが示された. 実際にそれぞれのdivergence(上からKL, reverse KL, Hellinger, Jensen)を用いて画像を生成した結果が以下の図である.

5.2 LSUN Natural Images

実験の設定

  • 様々なカテゴリーの画像からなるデータセットであるLSUNの中から今回の実験では 'classroom' のカテゴリの画像168,103枚用いて学習を行った
  • Generatorは, DCGANの論文と同じ構造で, 1つの全結合層と3つの畳み込み層からなる. その層の間には, 先ほどと同じようにBatch normalizationとReLUがある.
  • $V_{\omega}(x)$もDCGANの論文のDiscriminatorと同じで, CNN->Batch normalization->ELU->全結合層である

結果
それぞれのdivergenceを, GAN, KL, Hellingerにして画像を生成した結果以下のような画像が得られた.

6. 感想

  • GANのような二つの確率分布間の距離の最小化は、画像生成の他にも様々な分野で応用が可能だと思うので、本論文のようなf-divergenceの評価は重要だなと感じました
  • 全体的にフワッとしか理解できてない感じが否めないです

参考資料・スライド