線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(2)


この記事は線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(1)の続きです。
PRML下巻の、線形回帰を例に出しながら各種の推定手法を説明した部分をまとめています。
前回の記事を読まなくても、今回の記事単体でも読めると思います。

前回は線形回帰をカーネル法・ガウス過程・RVMを用いて解いたので、今回はEMアルゴリズムで解いてみます。
変分ベイズはさらに別の記事で扱います。

線形回帰(再掲)

前回の線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(1)の冒頭で述べた問題設定を再掲します。

$ n = 1,\ \ldots ,\ N $ に対して目標変数 $ t_n \in \mathbb{R} $ をパラメータ $
\boldsymbol{w} \in \mathbb{R}^d $ の線形結合 $ \boldsymbol{w}^{\top}\boldsymbol\phi(\boldsymbol x_n) $ で回帰することを考えます。
ただし $ \boldsymbol x_n $ は観測される入力変数、$ \phi(\boldsymbol x_n) \in \mathbb R^d $ は 何らかの写像 $ \boldsymbol \phi $ で入力変数を特徴空間に移したベクトルです。
例えば $ \boldsymbol \phi(x) = (1,\ x,\ x^2) $ とすれば2次の多項式回帰となります。

考える統計モデルは


\begin{align}
t_n &= \boldsymbol w^{\top} \boldsymbol \phi(x_n) + \varepsilon_n\ , \\
\varepsilon_n &\sim \mathcal{N}(0,\ \beta^{-1})\ , \\
\boldsymbol w &\sim \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I)  \tag{1}
\end{align}

です。 $ \varepsilon_n\ (n = 1,\ \ldots ,\ N) $ は独立とします。
すなわち、誤差 $\varepsilon_n$ は独立に平均 $0$, 分散 $\beta^{-1}$ の正規分布に従い、
パラメータ $ \boldsymbol w $ の事前分布は平均 $\boldsymbol 0$, 共分散行列 $\alpha^{-1}I$ の $ d $ 次元正規分布とします。

後のために式 $(1)$ をベクトル表記しておくと

\begin{align}
{\bf t} &= \Phi\boldsymbol w + \boldsymbol \varepsilon\ , \\
\boldsymbol \varepsilon &\sim \mathcal{N}(\boldsymbol 0,\ \beta^{-1}I)\ , \\
\boldsymbol w &\sim \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I)  \tag{2}
\end{align}

となります。ただし

{\bf t} = 
\begin{pmatrix}
t_1 \\
\vdots \\
t_N
\end{pmatrix}
\in \mathbb{R}^{N} , \quad
\boldsymbol \varepsilon = 
\begin{pmatrix}
\varepsilon_1 \\
\vdots \\
\varepsilon_N
\end{pmatrix}
\in \mathbb{R}^{N} , \quad
\Phi = 
\begin{pmatrix}
\boldsymbol \phi(\boldsymbol x_1)^{\top} \\
\vdots \\
\boldsymbol \phi(\boldsymbol x_N)^{\top}
\end{pmatrix}
 \in \mathbb{R}^{N \times d}

です。
このとき、$\boldsymbol w$ で条件づけた ${\bf t}$ の条件付き分布と ${\bf t}$ の周辺分布は

\begin{align}
p({\bf t}| \boldsymbol w) &= \mathcal{N}(\Phi\boldsymbol w,\ \beta^{-1}I)\ , \\
p({\bf t}) &= \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}\Phi\Phi^{\top} + \beta^{-1}I) \tag{3}
\end{align}

となります(2つ目の式は The Matrix Cookbook 式 $(355)$ より)。

目標は、新たな入力 $ \boldsymbol x $ に対する $ t $ の予測値、あるいは予測分布 $p(t | {\bf t}) $ を求めることです。

EMアルゴリズム

まず、線形回帰の問題設定から離れてEMアルゴリズムの一般的な定式化を述べます。

観測変数 $\boldsymbol X,$ 潜在変数 $\boldsymbol Z,$ パラメータ $\boldsymbol \theta$ をもつモデルを考えます。
$\boldsymbol X,\ \boldsymbol Z$ は確率変数、$\boldsymbol \theta$ は未知の定数です。

代表的な例は $K$ 個の混合成分を持った混合正規分布です。
$\boldsymbol X = (\boldsymbol x_1,\ \ldots,\ \boldsymbol x_N),\ \boldsymbol Z = (\boldsymbol
z_1,\ \ldots,\ \boldsymbol z_N),\ \boldsymbol \theta = (\pi_1,\ \ldots,\ \pi_K,\ \boldsymbol \mu_1,\ \ldots,\ \boldsymbol \mu_K,\ \Sigma_1,\ \ldots,\ \Sigma_K)$ で各 $\boldsymbol z_n$ は $K$ 次元の one-hot ベクトルで、対応する $\boldsymbol x_n$ がどの混合成分に属するかを表します。 $p(z_{nk} = 1) = \pi_k$ です。
潜在変数 $\boldsymbol Z$ を与えたもとでは、$z_{nk} = 1$ なる $k$ に対して $\boldsymbol x_n$ が 正規分布 $\mathcal{N}(\boldsymbol \mu_k,\ \Sigma_k)$ に従います。
EMアルゴリズムの説明に混合正規分布を用いるのは一般的で、わかりやすい記事(EMアルゴリズム徹底解説)もあります。

今回考えている線形回帰の場合、$\boldsymbol X,\ \boldsymbol Z,\ \boldsymbol \theta$ に対応するものは以下のようになります。

\begin{align}
\boldsymbol X &= {\bf t} = (t_1,\ \ldots,\ t_N)^{\top}, \\
\boldsymbol Z &= \boldsymbol w\ , \\
\boldsymbol \theta &= (\alpha,\ \beta)^{\top}. \tag{4}
\end{align}

EMアルゴリズムでは周辺尤度 $p(\boldsymbol X|\boldsymbol \theta)$ を最大化することで未知パラメータ $\boldsymbol \theta$ を点推定するのが目的です(経験ベイズ、または第二種の最尤推定)。

前回の記事では周辺尤度を具体的に書き下すことで周辺尤度を最大化する $\boldsymbol \theta$ を求めましたが、ここでは $p(\boldsymbol X|\boldsymbol \theta) = \int p(\boldsymbol X, \boldsymbol Z |\theta)\text{d}\boldsymbol Z$ を直接 $\boldsymbol \theta$ について最大化することは難しく、代わりに同時分布 $p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta)$ を最大化することはできると仮定します。

$\boldsymbol Z$ に関する任意の確率分布 $q(\boldsymbol Z)$ を用いて、対数周辺尤度を次のように変形します。

\begin{align}
\log p(\boldsymbol X|\boldsymbol \theta)
&= \int q(\boldsymbol Z)\log p(\boldsymbol X|\boldsymbol \theta)\text{d}\boldsymbol Z \\
&= \int q(\boldsymbol Z)\left(\log p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta) - \log p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta)\right)\text{d}\boldsymbol Z \\
&= \int q(\boldsymbol Z)\left(\log p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta) - \log q(\boldsymbol Z) - (\log p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta) - \log q(\boldsymbol Z)\right)\text{d}\boldsymbol Z \\
&= \int q(\boldsymbol Z)\log\frac{p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta)}{q(\boldsymbol Z)}\text{d}\boldsymbol Z - \int q(\boldsymbol Z)\log\frac{p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta)}{q(\boldsymbol Z)}\text{d}\boldsymbol Z \\
&= \mathcal{L}(q, \boldsymbol \theta) + \text{KL}(q(\boldsymbol Z)\|p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta)) \tag{5}
\end{align}

1行目では $\int q(\boldsymbol Z)\text{d}\boldsymbol Z = 1$ を用いました。

最終行の第2項は $q(\boldsymbol Z) $ と $p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta)$ の間のKLダイバージェンスで、常に非負です。この値は $q(\boldsymbol Z) = p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta)$ のときのみ $0$ となります。

最終行の第1項はいわゆる ELBO(Evidence Lower BOund) と呼ばれる量で、これを

\mathcal{L}(q, \boldsymbol \theta) = \int q(\boldsymbol Z)\log\frac{p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta)}{q(\boldsymbol Z)}\text{d}\boldsymbol Z \tag{6}

とおいています($\boldsymbol X$ にも依存していますが引数の $\boldsymbol X$ は省略しています)。
KLダイバージェンスが非負なため $\mathcal{L}(q, \boldsymbol \theta)$ は対数周辺尤度(エビデンス)の下界になっています。

対数周辺尤度 $\log p(\boldsymbol X|\boldsymbol \theta)$ を直接 $\boldsymbol \theta$ について最大化するのが難しいので、代わりにその下界である $\mathcal{L}(q, \boldsymbol \theta)$ を最大化することを考えます。
しかし $\mathcal{L}(q, \boldsymbol \theta)$ は $q$ に依存しているので、まず $q$ を適切に定めて $\mathcal{L}(q, \boldsymbol \theta)$ ができるだけ真の対数周辺尤度 $\log p(\boldsymbol X|\boldsymbol \theta)$ に近くなるようにする必要があります。ただし $q$ を最適化するためにも $\boldsymbol \theta$ の値が必要です。
そこで、分布 $q$ とパラメータ $\boldsymbol \theta$ を交互に最適化するのがEMアルゴリズムです。
$q$ の最適化がEステップであり、$\boldsymbol \theta$ の最適化がMステップです。

Eステップ

Eステップでは、$\boldsymbol \theta$ の値を所与として分布 $q$ を最適化します。
現在の $\boldsymbol \theta$ の値を $\boldsymbol \theta = \boldsymbol \theta^{\text{old}}$ とします。

$\boldsymbol \theta = \boldsymbol \theta^{\text{old}}$ を式 $(5)$ に代入すると

\begin{align}
\log p(\boldsymbol X|\boldsymbol \theta^{\text{old}})
&= \mathcal{L}(q, \boldsymbol \theta^{\text{old}}) + \text{KL}(q(\boldsymbol Z)\|p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}}))
\end{align}

となります。
上の式より、$ \text{KL}(q(\boldsymbol Z)|p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})) $ を最小化すれば、対数周辺尤度と下界のギャップが最小になることがわかります。

これを最小化する $q(\boldsymbol Z) = p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})$ を計算するのがEステップです。

Mステップ

Mステップでは、Eステップで計算した $q(\boldsymbol Z) = p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})$ を所与として、下界 $\mathcal{L}(q, \boldsymbol \theta)$ を $\boldsymbol \theta$ に関して最大化します。

$q(\boldsymbol Z) = p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})$ を式 $(6)$ に代入すると

\begin{align}
\mathcal{L}(q, \boldsymbol \theta)
&= \int p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})\log\frac{p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta)}{p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})}\text{d}\boldsymbol Z \\
&= \int p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})\log p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta)\text{d}\boldsymbol Z - \int p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})\log p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})\text{d}\boldsymbol Z \\
&= \mathbb{E}_{p(\boldsymbol Z|\boldsymbol X, \boldsymbol \theta^{\text{old}})}\left[ \log p(\boldsymbol X, \boldsymbol Z|\boldsymbol \theta) \right] + \text{const.} \tag{7}
\end{align}

となります。ここで $\boldsymbol \theta$ に依存しない第2項は $\text{const.}$ としました。
式 $(7)$ を $\boldsymbol \theta$ に関して最大化するのがMステップです。

Mステップで求めた $\boldsymbol \theta$ の値は、次のEステップで $\boldsymbol \theta^{\text{old}}$ として用いられ、反復計算が行われます。

以上のようにEステップとMステップを繰り返してパラメータ $\boldsymbol \theta$ を最適化するのがEMアルゴリズムです。
この反復計算で対数周辺尤度と下界はともに単調増加するため、$\boldsymbol \theta$ は局所最適解に収束します。ただし一般には大域最適解に収束するとは限りません。

線形回帰への適用

今述べたEMアルゴリズムを線形回帰の未知パラメータ $\boldsymbol \theta = (\alpha,\ \beta)^{\top}$ の推定に適用します。
$\boldsymbol \theta = (\alpha,\ \beta)^{\top}$ さえ求めれば式 $(3)$ を用いて $\boldsymbol w$ の事後分布 $p(\boldsymbol w | {\bf t})$ と予測分布 $p(t | {\bf t})$ を計算することができます。

Eステップ

式 $(4)$ の対応から、Eステップでは $p(\boldsymbol w|{\bf t}, \alpha^{\text{old}}, \beta^{\text{old}}),$ すなわち $\boldsymbol w$ の事後分布を求めます。

$p(\boldsymbol w|{\bf t}, \alpha^{\text{old}}, \beta^{\text{old}})$ の計算は前回の線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(1)のRVMの項で行った計算と同様なので省略します。
事後分布を計算すると

\begin{align}
p(\boldsymbol w|{\bf t}, \alpha^{\text{old}}, \beta^{\text{old}}) &= \mathcal{N}(\boldsymbol m,\ S)\ ,
\end{align}

ただし

\begin{align}
\boldsymbol m &= \beta^{\text{old}} S \Phi^{\top}{\bf t}\ , \\
S &= (\alpha^{\text{old}} I + \beta^{\text{old}} \Phi^{\top}\Phi)^{-1} \tag{8}
\end{align}

となります。

Mステップ

式 $(7)$ より、Mステップでは

\mathbb{E}_{p(\boldsymbol w|{\bf t}, \alpha^{\text{old}}, \beta^{\text{old}})}\left[ \log p(\boldsymbol t, \boldsymbol w|\alpha, \beta) \right]

を $\alpha,\ \beta$ に関して最大化します。
長いので $\mathbb{E}_{p(\boldsymbol w|{\bf t}, \alpha^{\text{old}}, \beta^{\text{old}})}$ を $\mathbb{E}_{\boldsymbol w}$ と書くことにしてこれを計算すると

\begin{align}
&\mathbb{E}_{p(\boldsymbol w|{\bf t}, \alpha^{\text{old}}, \beta^{\text{old}})}\left[ \log p(\boldsymbol t, \boldsymbol w|\alpha, \beta) \right] \\
=\ & \mathbb{E}_{\boldsymbol w}\left[ \log p({\bf t}|\boldsymbol w, \beta) + \log p(\boldsymbol w|\alpha) \right] \\
=\ & \mathbb{E}_{\boldsymbol w}\left[ \log \mathcal{N}(\Phi\boldsymbol w,\ \beta^{-1}I) + \log\mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I) \right] \\
=\ & \mathbb{E}_{\boldsymbol w}\left[ \frac{N}{2}\log\frac{\beta}{2\pi} - \frac{\beta}{2} \| {\bf t} - \Phi\boldsymbol w \|^2 + \frac{d}{2} \log\frac{\alpha}{2\pi} - \frac{\alpha}{2}\|\boldsymbol w\|^2 \right] \\
=\ & \frac{N}{2}\log\beta - \frac{\beta}{2}\left( \|{\bf t}\|^2 - 2{\bf t}^{\top}\Phi \mathbb{E}_{\boldsymbol w}[\boldsymbol w] + \mathbb{E}_{\boldsymbol w}[\boldsymbol w^{\top}\Phi^{\top}\Phi\boldsymbol w] \right) + \frac{d}{2}\log\alpha - \frac{\alpha}{2}\mathbb{E}_{\boldsymbol w}[\|\boldsymbol w\|^2] + \text{const.} \tag{9}
\end{align}

となります。
式 $(9)$ を $\alpha$ で微分して $0$ とおくと

\begin{align}
&\frac{d}{2\alpha} - \frac{\mathbb{E}_{\boldsymbol w}[\|\boldsymbol w\|^2]}{2} = 0 \\
\Rightarrow \quad & \alpha = \frac{d}{\mathbb{E}_{\boldsymbol w}[\|\boldsymbol w\|^2]} = \frac{d}{\boldsymbol m^{\top}\boldsymbol m + \text{tr}S}
\end{align} \tag{10}

となり、$\alpha$ の更新式が得られます。
ここで、$\mathbb{E}_{\boldsymbol w}[| \boldsymbol w |^2]$ の計算はThe Matrix Cookbookの式 $(318)$ より一般に

\mathbb{E}[\boldsymbol w^{\top}A\boldsymbol w]
= \mathbb{E}[\boldsymbol w]^{\top}A\mathbb{E}[\boldsymbol w] + \text{tr}(A\text{cov}(\boldsymbol w)) \tag{11}

なので

\begin{align}
\mathbb{E}_{\boldsymbol w}[\|\boldsymbol w\|^2]
&= \mathbb{E}[\boldsymbol w]^{\top}\mathbb{E}[\boldsymbol w] + \text{tr}(\text{cov}(\boldsymbol w)) \\
&= \boldsymbol m^{\top}\boldsymbol m + \text{tr}S
\end{align}

であることを用いました。

つづいて $\beta$ の更新式を求めるために式 $(9)$ を $\beta$ で微分すると

\begin{align}
&\frac{N}{2\beta} - \frac{1}{2}\left( \|{\bf t}\|^2 - 2{\bf t}^{\top}\Phi\mathbb{E}_{\boldsymbol w}[\boldsymbol w] + \mathbb{E}_{\boldsymbol w}[\boldsymbol w^{\top}\Phi^{\top}\Phi\boldsymbol w] \right) \\
=\ & \frac{N}{2\beta} - \frac{1}{2}\left( \|{\bf t}\|^2 - 2{\bf t}^{\top}\Phi\boldsymbol m + \boldsymbol m^{\top}\Phi^{\top}\Phi\boldsymbol m + \text{tr}(\Phi^{\top}\Phi S) \right)
\end{align}

となり(この計算にも式 $(11)$ を用いています)、これを $0$ とおくと

\beta = \frac{N}{\| {\bf t} - \Phi\boldsymbol m \|^2 + \text{tr}(\Phi^{\top}\Phi S)} \tag{12}

という $\beta$ の更新式が得られました。
$ \text{tr}(\Phi^{\top}\Phi S)$ はもう少し簡単にすることができて

\begin{align}
\text{tr}(\Phi^{\top}\Phi S)
&= \frac{1}{\beta^{\text{old}}}\text{tr}((\alpha^{\text{old}} I + \beta^{\text{old}} \Phi^{\top}\Phi)S - \alpha^{\text{old}} S) \\
&= \frac{1}{\beta^{\text{old}}}(\text{tr}I - \alpha^{\text{old}}\text{tr}S) \\
&= \frac{d - \alpha^{\text{old}} \text{tr}S}{\beta^{\text{old}}}
\end{align}

と計算できます。

まとめると、Eステップで式 $(8)$ を用いて $\boldsymbol m,\ S$ を更新し、Mステップで式 $(10), (12)$ を用いて $\alpha,\ \beta$ を更新する計算を繰り返すことでハイパーパラメータ $\alpha,\ \beta$ の点推定ができます。
求めた値を式 $(3)$ に代入することで、事後分布 $p(\boldsymbol w|{\bf t})$ や予測分布 $p(t|{\bf t})$ の計算が可能になります。

ここまでのまとめ

EMアルゴリズムだけで長くなってしまったので、今回はここまでとして変分ベイズは次回扱います。

今回の線形回帰の設定では周辺尤度が簡単に求められるため、前回のRVMの項で行ったように直接周辺尤度を最大化してハイパーパラメータを点推定することができましたが、EMアルゴリズムはより広い統計モデルに対してハイパーパラメータの推定を可能にしてくれます。

EMアルゴリズムの説明としては本文中でも挙げたように混合正規分布の例のほうがわかりやすいのですが、線形回帰という同じ設定でほかの推定手法と比較するのも見通しがよいかと思いました。

次回の変分ベイズでは、ハイパーパラメータを定数として扱って点推定するのでなく、事前分布を設定して完全にベイズ的に推論を行います。事後分布を陽に求めることは一般に難しいので因数分解した形で近似した近似事後分布を求めます。