PRML復活の呪文 part21 (5.6)


TL;DR

  • 条件付き分布$ p(t|x) $がガウス分布と大きく離れていると通常のニューラルネットネットはうまくいかない
  • この場合、$ p(t|x) $にパラメトリックな混合モデルを仮定し、混合モデルのパラメータをニューラルネットで推定しよう
  • このようなネットワークを混合密度ネットワークとよぶ

5.6 混合密度ネットワーク

教師あり学習の目標は、条件付き分布$ p(t|x) $をモデル化することであり、多くの回帰問題では$ p(t|x) $はガウス分布と仮定される。
しかし、実用的な問題の中には$ p(t|x) $をガウス分布とは全く異なる分布を用いる場合がある。この場合、ガウス分布の仮定では精度の悪い予測結果しか得られない。

例題を考えよう。下図の緑点は以下の条件で生成したデータ点である。
図の左側:

  • $x$:区間(0,1)に一様分布する確率変数$x$をサンプリングしたもの
  • $t$:$ x+0.3 \sin (2 \pi x) $に小さいノイズを加えたもの

一方、図の右側は$x$と$t$を入れ替えてある。すなわち、

  • $t$:区間(0,1)に一様分布する確率変数$t$をサンプリングしたもの
  • $x$:$ t+0.3 \sin (2 \pi t) $に小さいノイズを加えたもの

そして二乗和誤差関数を最小化して求まったフィッティング結果が赤線で記されている。
二乗和誤差の最小化はノイズがガウス分布であるという仮定での最尤推定解と同じである1ことを思い出そう。

図の$x = 0.3 $ぐらいのところに青線を引いたが、左側は$x = 0.3 $のときにとる$t$の値が0.5付近で上にも下にも散らばっており、ガウス分布に近い。そのため、フィッティング結果の赤線は良好な精度となっている。
一方、右側は$x = 0.3 $のときにとる$t$の値が0.1付近、0.6付近と2か所に固まって散らばっており、ガウス分布とはかけ離れた分布である。結果、フィッティング結果も悪い精度となっている。
(そもそも$x=0.3$という情報だけから$t$が0.1付近か0.6付近か、など分かるわけがない)

そこで、条件付き分布$ p(t|x) $に混合モデルを用いて条件付き分布をモデル化する方法を見ていこう。
端的には、通常のニューラルネットワークは$x$を入力して$t$を出力(予測)するモデルであるのに対し、混合密度ネットワークではパラメトリック2な混合分布であると仮定し、$x$を入力して$t$の代わりに混合係数や(混合)ガウス分布の平均や分散といったパラメータを出力(予測)するモデルである。

具体例を見ていこう。$ p(t|x) $が$K$個のガウス分布の混合モデル(参照⇒part7)で表せると仮定し、目標変数$t$は$L$次元とする。式で表すと、

$$
p(t|x) = \sum_{k=1}^K \pi_k (x) N(t | \mu_k(x), \sigma^2_k(x) I ) \tag{5.148}
$$

なお、この例では目標変数が連続であることを仮定しているのでガウス分布の混合モデルを用いたが、目標変数が離散である場合はベルヌーイ分布などのほかの分布を用いることができる。

この混合モデルは以下のパラメータで表現できる。

  • 混合係数が$\pi_1(x), \cdots, \pi_K(x) $の$K$個
  • ガウス分布の分散が$ \sigma^2_1(x), \cdots, \sigma^2_L(x) $の$K$個
  • ガウス分布の平均ベクトル$ \mu_k(x) $は$L$次元で、$ \mu_k(x) $が$ \mu_1(x), \cdots, \mu_K(x) $個の$K$個あるので合計$ K \times L $個

これらの総数は$ (L+2)K $個あるので、混合密度ネットワークの出力ユニット数を$ (L+2)K $個用意し、$x$を入力して混合係数、ガウス分布の平均、ガウス分布の分散を予測すればよい。

ただし、混合係数と分散については制約条件を満たすために以下の工夫を入れる。

混合係数は

$$
\sum_{k=1}^K \pi_k(x) = 1, 0 \leq \pi_k(x) \leq 1 \tag{5.149}
$$

を満たすために混合密度ネットワークの出力$\pi_1(x), \cdots, \pi_K(x) $にソフトマックス関数にかけたものを出力値とする。

分散は$ \sigma^2_k(x) \geq 0 $を満たすために混合密度ネットワークの出力にexpをとったものを出力値とする。

混合密度ネットワークの重みも通常のニューラルネットと同様に、誤差逆伝播法で求めることができる。誤差逆伝播法の計算に必要な誤差関数と、誤差のネットワーク出力に関する微分を求めよう。

誤差関数

誤差関数は負の対数尤度として定義できる。尤度、すなわち混合モデルの平均や分散などのパラメータをこれだと決めたときに観測データが得られる確率は

\begin{align}

\prod_{n=1}^N p(t|x) =
\prod_{n=1}^N \left\{ 
\sum_{k=1}^K \pi_k(x_n, w) N(t_n | \mu_k(x_n, w), \sigma^2_k(x_n, w) I )
\right\}

\end{align}

と書ける。なお、上式では混合係数$ \pi_k(x_n) $とガウス分布の平均/分散が重み$w$に依存していることを明示した。

よって、誤差関数は

\begin{align}

E(w) = - \sum_{n=1}^N \ln \left\{ 
\sum_{k=1}^K \pi_k(x_n, w) N(t_n | \mu_k(x_n, w), \sigma^2_k(x_n, w) I ) \tag{5.153}
\right\}

\end{align}

であり、あるデータ点$x_n$に関する誤差は

\begin{align}

E_n(w) = -  \ln \left\{ 
\sum_{k=1}^K \pi_k(x_n, w) N(t_n | \mu_k(x_n, w), \sigma^2_k(x_n, w) I ) \tag{5.153}
\right\}

\end{align}

である。

混合係数を決めるネットワークの出力に関する微分(演習5.34)

$ \pi_k(x) $を決定する出力ユニットの出力を$ a_k^{\pi} $と表記する。ここで求めたいものは$ \frac{ \partial E_n}{ \partial a_k^{\pi} }$である。
これからの式変形のために、いくつかの表記を導入しておこう。

\begin{align}

N_{nk} &= N(t_n | \mu_k(x_n), \sigma^2_k (x_n) ) \\
\gamma_{nk} &= \frac{ \pi_k N_{nk} }{ \sum_{l=1}^K \pi_l N_{nl} } \tag{5.154} \\
E_n(w) &= -  \ln \left\{ 
\sum_{k=1}^K \pi_k N_{nk} \tag{5.153'}
\right\}

\end{align}

ソフトマックス関数の影響で$ a_k^{\pi} $の値にはすべての混合係数$ \pi_j $が影響していることに留意して、連鎖律を使うと

$$
\frac{ \partial E_n}{ \partial a_k^{\pi} } = \sum_{j=1}^K \frac{ \partial E_n}{ \partial \pi_j } \frac{ \partial \pi_j }{ \partial a_k^{\pi} }
$$

である。$E_n$を微分すると$\ln$の中身が分母になって、分子は$ - \pi_j N_{nj} $の$ \pi_j $が微分によって消えることから

$$
\frac{ \partial E_n}{ \partial \pi_j } = - \frac{ N_{nj} }{ \sum_{l=1}^K \pi_l N_{nl} } = - \frac{ \gamma_{nj} }{ \pi_j }
$$

次に$ \frac{ \partial \pi_j }{ \partial a_k^{\pi} } $だが、これはソフトマックス関数の微分なのでテキストの式(4.106)の公式をそのまま使える。

$$
\frac{ \partial \pi_j }{ \partial a_k^{\pi} } = \pi_j ( I_{jk} - \pi_k )
$$

ここで$ I_{jk} $は単位行列の要素を表す。以上により、

\begin{align}

\frac{ \partial E_n}{ \partial a_k^{\pi} } &= - \sum_{j=1}^K \frac{ \gamma_{nj} }{ \pi_j } \pi_j ( I_{jk} - \pi_k ) = - \sum_{j=1}^K \gamma_{nj} ( I_{jk} - \pi_k ) \\
&= - \sum_{j=1}^K \gamma_{nj} I_{jk} + \sum_{j=1}^K \gamma_{nj} \pi_k \\
I_{jk}はj=kのときのみ1となり、その他は0であることと\sum_{j=1}^K \gamma_{nj} = 1となることから \\
&= \pi_k - \gamma_{nk} \tag{5.155}
\end{align}

各ガウス分布の平均を決めるネットワークの出力に関する微分(演習5.35)

$k$番目のガウス分布の$l$次元目の出力に関する平均$ \mu_{kl} $を決定する出力ユニットの出力を$ a_{kl}^{\mu} $と表記する。ここで求めたいものは$ \frac{ \partial E_n}{ \partial a_{kl}^{\mu} }$であるが、$ a_{kl}^{\mu} = \mu_{kl} $なので$ \frac{ \partial E_n}{ \partial \mu_{kl} }$を求めればよい。

微分した結果の分母は上記演習5.34と同じで$E_n$の$ \ln $の中身になる。分子は$ \pi_k N_{nk}$の微分となる。

\begin{align}

N_{nk} &= \frac{1}{ (2 \pi)^{L/2} } \frac{1}{ |\sigma_k^2 I|^{1/2} } \exp
\left( - \frac{1}{2} (t_n - \mu_k)^T \frac{1}{\sigma_k^2 I} (t_n - \mu_k)
\right) \\
&= \frac{1}{ (2 \pi)^{L/2} } \cdot \sigma_k^{-L} \exp
\left( - \frac{1}{ 2 \sigma_k^2} || t_n - \mu_k ||^2
\right) \tag{X} \\
&= C \exp
\left( - \frac{1}{2} (t_n - \mu_k)^T \frac{1}{\sigma_k^2 I} (t_n - \mu_k)
\right) 

\end{align}

とおくと、$ \pi_k N_{nk}$の微分は

$$
\pi_k C \exp
\left( - \frac{1}{2} (t_n - \mu_k)^T \frac{1}{\sigma_k^2 I} (t_n - \mu_k)
\right) \times ( \expの中身の微分)
$$
であると分かる。
なお、式(X)は後述の演習5.36で使うため変形しておいた。また、行列式の変形である$ \frac{1}{ |\sigma_k^2 I|^{1/2} } = \sigma_k^{-L} $が疑問な場合はpart2参照。
$L$次元ガウス分布の$l$次元目のexpの中身は

$$
- \frac{1}{ 2 \sigma_k^2} (t_{nl} - \mu_{kl} )^2 \tag{A}
$$

なので、(A)を$ \mu_{kl} $に関して微分すると

$$
\frac{1}{\sigma_k^2} ( \mu_{kl} - t_{nl} )
$$

以上により、

\begin{align}

\frac{ \partial E_n}{ \partial a_{kl}^{\mu} } &= \left( 
\frac{ \pi_k N_{nk} }{ \sum_{l=1}^K \pi_l N_{nl} }
\right) \times \frac{1}{\sigma_k^2} ( \mu_{kl} - t_{nl} ) \\

&= \gamma_{nk} \left\{ 
\frac{ \mu_{kl} - t_{nl} }{ \sigma_k^2 }
\right\} \tag{5.156}

\end{align}

各ガウス分布の分散を決めるネットワークの出力に関する微分(演習5.36)

$ \pi_k(x) $を決定する出力ユニットの出力を$ a_k^{\sigma} $と表記する。ここで求めたいものは$ \frac{ \partial E_n}{ \partial a_k^{\sigma} }$である。

連鎖律を使って

$$
\frac{ \partial E_n}{ \partial a_k^{\sigma} } = \frac{ \partial E_n}{ \partial \sigma_k } \frac{ \sigma_k }{ \partial a_k^{\sigma} }
$$

となる。
微分した結果の分母はやはり上記演習5.34と同じで$E_n$の$ \ln $の中身になる。分子は$ \pi_k N_{nk}$の微分となるが、$ \pi_k $は$ \sigma_k $に依存しないことから

$$
\frac{ \partial E_n}{ \partial \sigma_k } = - \frac{ \pi_k }{ \sum_l \pi_l N_{nl} } \cdot ( N_{nk} )'
$$

となる。式(X)を使って$ ( N_{nk} )' $を変形していくが、$ \sigma_k $がexpの中と外の2か所にあるので$ (fg)' = f'g + fg' $の公式を用いよう。

\begin{align}

( N_{nk} )' &= \frac{1}{(2 \pi)^{L/2}} (-L) \sigma_k^{-L-1} \exp \left\{ 
- \frac{1}{2 \sigma_k^2} || t_n - \mu_n ||^2
\right\} + 
\frac{1}{(2 \pi)^{L/2}} \sigma_k^{-L} \exp \left\{ 
- \frac{1}{2 \sigma_k^2} || t_n - \mu_n ||^2
\right\} \left( \sigma_k^{-3} ||t_n - \mu_k ||^2 \right) \\

&= \frac{1}{(2 \pi)^{L/2}} \exp \left\{ 
- \frac{1}{2 \sigma_k^2} || t_n - \mu_n ||^2
\right\} \left\{ 
(-L) \sigma_k^{-L-1} + \sigma_k^{-L-3} || t_n - \mu_n ||^2
\right\} \\

&= \frac{1}{(2 \pi)^{L/2}} \sigma_k^{-L} \exp \left\{ 
- \frac{1}{2 \sigma_k^2} || t_n - \mu_n ||^2
\right\} \left\{ 
(-L) \sigma_k^{-1} + \sigma_k^{-3} || t_n - \mu_n ||^2
\right\} \\

&= N_{nk} \left\{ 
(-L) \sigma_k^{-1} + \sigma_k^{-3} || t_n - \mu_n ||^2
\right\}

\end{align}

よって:

\begin{align}

\frac{ \partial E_n}{ \partial \sigma_k } &= - \frac{ \pi_k }{ \sum_l \pi_l N_{nl} } \cdot ( N_{nk} )' \\
&= - \frac{ \pi_k N_{nk} }{ \sum_l \pi_l N_{nl} }
\left\{ 
(-L) \sigma_k^{-1} + \sigma_k^{-3} || t_n - \mu_n ||^2
\right\} \\

&= - \gamma_{nk} \left\{ 
(-L) \sigma_k^{-1} + \sigma_k^{-3} || t_n - \mu_n ||^2
\right\} \\

\end{align}

さて、$\sigma_k = \exp ( a_k^{\sigma} ) $により計算されることから

$$
\frac{ \partial \sigma_k }{ \partial a_k^{\sigma} } = \exp ( a_k^{\sigma} ) = \sigma_k
$$

以上により:

$$
\frac{ \partial E_n}{ \partial a_k^{\sigma} } = \gamma_{nk} \left( L -
\frac{ || t_n - \mu_n ||^2 }{ \sigma_k^2} \tag{5.157}
\right)
$$


  1. 参照:テキストP29、1.2.5節「曲線フィッティング再訪」 

  2. 平均や分散といった少ない数のパラメータが決まると分布の形が1つに決まるような分布のこと。ガウス分布など。