線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(3)


この記事は線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(2)の続きです。
PRML下巻の、線形回帰を例に出しながら各種の推定手法を説明した部分をまとめています。

初回(1)では線形回帰をカーネル法・ガウス過程・RVMを用いて解き、前回(2)ではEMアルゴリズムを用いて解きました。今回は変分ベイズを用いて完全にベイズ的な推論を行う方法を解説します。

線形回帰(再掲)

初回・前回と同様の、線形回帰の問題設定を述べます。

$ n = 1,\ \ldots ,\ N $ に対して目標変数 $ t_n \in \mathbb{R} $ をパラメータ $
\boldsymbol{w} \in \mathbb{R}^d $ の線形結合 $ \boldsymbol{w}^{\top}\boldsymbol\phi(\boldsymbol x_n) $ で回帰することを考えます。
ただし $ \boldsymbol x_n $ は観測される入力変数、$ \phi(\boldsymbol x_n) \in \mathbb R^d $ は 何らかの写像 $ \boldsymbol \phi $ で入力変数を特徴空間に移したベクトルです。
例えば $ \boldsymbol \phi(x) = (1,\ x,\ x^2) $ とすれば2次の多項式回帰となります。

考える統計モデルは


\begin{align}
t_n &= \boldsymbol w^{\top} \boldsymbol \phi(x_n) + \varepsilon_n\ , \\
\varepsilon_n &\sim \mathcal{N}(0,\ \beta^{-1})\ , \\
\boldsymbol w &\sim \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I)  \tag{1}
\end{align}

です。 $ \varepsilon_n\ (n = 1,\ \ldots ,\ N) $ は独立とします。
すなわち、誤差 $\varepsilon_n$ は独立に平均 $0$, 分散 $\beta^{-1}$ の正規分布に従い、
パラメータ $ \boldsymbol w $ の事前分布は平均 $\boldsymbol 0$, 共分散行列 $\alpha^{-1}I$ の $ d $ 次元正規分布とします。

後のために式 $(1)$ をベクトル表記しておくと

\begin{align}
{\bf t} &= \Phi\boldsymbol w + \boldsymbol \varepsilon\ , \\
\boldsymbol \varepsilon &\sim \mathcal{N}(\boldsymbol 0,\ \beta^{-1}I)\ , \\
\boldsymbol w &\sim \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I)  \tag{2}
\end{align}

となります。ただし

{\bf t} = 
\begin{pmatrix}
t_1 \\
\vdots \\
t_N
\end{pmatrix}
\in \mathbb{R}^{N} , \quad
\boldsymbol \varepsilon = 
\begin{pmatrix}
\varepsilon_1 \\
\vdots \\
\varepsilon_N
\end{pmatrix}
\in \mathbb{R}^{N} , \quad
\Phi = 
\begin{pmatrix}
\boldsymbol \phi(\boldsymbol x_1)^{\top} \\
\vdots \\
\boldsymbol \phi(\boldsymbol x_N)^{\top}
\end{pmatrix}
 \in \mathbb{R}^{N \times d}

です。
このとき、$\boldsymbol w$ で条件づけた ${\bf t}$ の条件付き分布と ${\bf t}$ の周辺分布は

\begin{align}
p({\bf t}| \boldsymbol w) &= \mathcal{N}(\Phi\boldsymbol w,\ \beta^{-1}I)\ , \\
p({\bf t}) &= \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}\Phi\Phi^{\top} + \beta^{-1}I) \tag{3}
\end{align}

となります(2つ目の式は The Matrix Cookbook 式 $(355)$ より)。

目標は、新たな入力 $ \boldsymbol x $ に対する $ t $ の予測値、あるいは予測分布 $p(t | {\bf t}) $ を求めることです。

変分ベイズ

まず、線形回帰の問題設定から離れて変分ベイズの一般的な定式化を述べます。

観測変数を $\boldsymbol X$ とし、潜在変数とパラメータを合わせて $\boldsymbol Z$ とします。
EMアルゴリズムでは潜在変数を $\boldsymbol Z,$ パラメータを $\boldsymbol \theta$ とおいていましたが、今はこの $\boldsymbol Z$ と $\boldsymbol \theta$ をまとめて $\boldsymbol Z$ で表します。

いま、$\boldsymbol Z$ に事前分布を仮定したもとで $\boldsymbol Z$ の事後分布 $ p(\boldsymbol Z | \boldsymbol X)$ を求めたいとします。
しかし事後分布の計算が困難であるため、何らかの方法で事後分布を近似したい、という状況を考えます。

事後分布 $ p(\boldsymbol Z | \boldsymbol X)$ を近似したいというモチベーションからいったん離れるようですが、ここで次のような計算をします。
EMアルゴリズムでの計算と同様に、周辺尤度 $\log p(\boldsymbol X)$ は次のように下界 $\mathcal{L}(q)$ とKLダイバージェンス $\text{KL}(q || p)$ の和に分解できます:

\begin{align}
\log p(\boldsymbol X)
&= \int q(\boldsymbol Z) \log p(\boldsymbol X)\text{d}\boldsymbol Z \\
&= \int q(\boldsymbol Z) (\log p(\boldsymbol X, \boldsymbol Z) 
- \log p(\boldsymbol Z | \boldsymbol X))\text{d}\boldsymbol Z \\
&= \int q(\boldsymbol Z) (\log p(\boldsymbol X, \boldsymbol Z) - \log q(\boldsymbol Z) 
- (\log p(\boldsymbol Z | \boldsymbol X) - \log q(\boldsymbol Z)))\text{d}\boldsymbol Z \\
&= \int q(\boldsymbol Z)\log\frac{p(\boldsymbol X, \boldsymbol Z)}{q(\boldsymbol Z)}\text{d}\boldsymbol Z - \int q(\boldsymbol Z)\log\frac{p(\boldsymbol Z|\boldsymbol X)}{q(\boldsymbol Z)}\text{d}\boldsymbol Z \\
&= \mathcal{L}(q) + \text{KL}(q(\boldsymbol Z)\|p(\boldsymbol Z|\boldsymbol X)) .\tag{4}
\end{align}

ここで

\begin{align}
\mathcal{L}(q)
&= \int q(\boldsymbol Z)\log\frac{p(\boldsymbol X, \boldsymbol Z)}{q(\boldsymbol Z)}\text{d}\boldsymbol Z, \\
\text{KL}(q(\boldsymbol Z)\|p(\boldsymbol Z|\boldsymbol X))
&= \int q(\boldsymbol Z)\log\frac{q(\boldsymbol Z)}{p(\boldsymbol Z|\boldsymbol X)}\text{d}\boldsymbol Z
\end{align}

です。
$ \text{KL}(q(\boldsymbol Z)|p(\boldsymbol Z|\boldsymbol X))$ を最小化する $q(\boldsymbol Z)$ は、ちょうどいま求めたい事後分布 $ p(\boldsymbol Z|\boldsymbol X)$ です。
そして式 $(4)$ の左辺 $\log p(\boldsymbol X)$ は $q$ に依存しないため、$ \text{KL}(q(\boldsymbol Z)|p(\boldsymbol Z|\boldsymbol X))$ の最小化は周辺尤度の下界 $\mathcal{L}(q)$ の最大化と同値です。

よって、事後分布を求めるためには、周辺尤度の下界 $\mathcal{L}(q)$ を最大化する分布 $q(\boldsymbol Z)$ を求めればよい、ということがわかりました。

とはいえ、$\mathcal{L}(q)$ の最大化もそのままでは難しいので、分布 $q$ に制約を加えることで近似的に事後分布を求めます。
分布 $q$ に加える制約として、因数分解の形を仮定する平均場近似を行い

q(\boldsymbol Z) = \prod_{i=1}^Mq_i(\boldsymbol Z_i)

と仮定します。このもとで、$\mathcal{L}(q)$ から $q_j(\boldsymbol Z_j)$ に依存する部分のみを取り出すと

\begin{align}
\mathcal{L}(q)
&= \int q(\boldsymbol Z)\log\frac{p(\boldsymbol X, \boldsymbol Z)}{q(\boldsymbol Z)}\text{d}\boldsymbol Z \\
&= \int \prod_{i=1}^Mq_i(\boldsymbol Z_i)\log\frac{p(\boldsymbol X, \boldsymbol Z)}{\prod_{i=1}^Mq_i(\boldsymbol Z_i)}\text{d}\boldsymbol Z \\
&= \int \prod_{i=1}^Mq_i(\boldsymbol Z_i)\left(\log p(\boldsymbol X, \boldsymbol Z) - \sum_{i=1}^M \log q_i(\boldsymbol Z_i)\right)\text{d}\boldsymbol Z \\
&= \int q_j(\boldsymbol Z_j)\left(\int \log p(\boldsymbol X, \boldsymbol Z)\prod_{i\neq j} q_i(\boldsymbol Z_i) \text{d}\boldsymbol Z_i\right) \text{d}\boldsymbol Z_j - \int q_j(\boldsymbol Z_j) \log q_j(\boldsymbol Z_j) \text{d}\boldsymbol Z_j + \text{const.} \\
&= \int q_j(\boldsymbol Z_j) \log \tilde{p}(\boldsymbol Z_j) \text{d}\boldsymbol Z_j - \int q_j(\boldsymbol Z_j) \log q_j(\boldsymbol Z_j) \text{d}\boldsymbol Z_j + \text{const.} \\
&= - \text{KL}(q_j(\boldsymbol Z_j) || \tilde{p}(\boldsymbol Z_j)) + \text{const.} 
\end{align}

です。ここで $q_j$ に依存しない項は $\text{const.}$ にまとめており、
また分布 $\tilde{p}(\boldsymbol Z_j)$ を

\begin{align}
\log\tilde{p}(\boldsymbol Z_j)
&= \int \log p(\boldsymbol X, \boldsymbol Z)\prod_{i\neq j} q_i(\boldsymbol Z_i) \text{d}\boldsymbol Z_i + \text{const.} \\
&= \mathbb{E}_{i\neq j}[\log p(\boldsymbol X, \boldsymbol Z)] + \text{const.}
\end{align}

で定めています。つまり $\tilde{p}(\boldsymbol Z_j)$ は、(正規化定数を除いて) $\boldsymbol Z_j$ 以外の全ての $\boldsymbol Z_i$ で $\log p(\boldsymbol X, \boldsymbol Z)$ の期待値をとったものです。

先の計算から、$\mathcal{L}(q)$ を $q_j$ に関して最大化するためには $\text{KL}(q_j(\boldsymbol Z_j) || \tilde{p}(\boldsymbol Z_j))$ を最小化すればよいので、最適解 $q_j^*(\boldsymbol Z_j)$ (の対数)は

\begin{align}
\log q_j^*(\boldsymbol Z_j) 
&= \log\tilde{p}(\boldsymbol Z_j) \\
&= \mathbb{E}_{i\neq j}[\log p(\boldsymbol X, \boldsymbol Z)] + \text{const.} \tag{5}
\end{align}

となります。
右辺は $q_i (i \neq j) $ たちに依存しているので、この式はすべての $q_j$ について陽に解ける式にはなっていませんが、この式を用いて反復計算を行うことで、平均場近似

q(\boldsymbol Z) = \prod_{i=1}^Mq_i(\boldsymbol Z_i)

を仮定したときの近似事後分布を求めることができます。

線形回帰への適用

今述べた方法を、線形回帰のパラメータ $\boldsymbol w$ とハイパーパラメータ $\alpha$ の事後分布の近似に用います。
$ \alpha $ は $\boldsymbol w$ の事前分布の精度パラメータであり、$ \alpha $ の事前分布は共役事前分布であるガンマ分布 $\text{Ga}(\alpha | a_0, b_0)$ とします。
ガンマ分布の確率密度関数は

\text{Ga}(x | a, b) = \frac{b^a}{\Gamma(a)} x^{a-1}\text{e}^{-bx}

($\Gamma(\cdot)$ はガンマ関数)です。
前節の記法との対応は

\begin{align}
\boldsymbol X &= {\bf t} = (t_1,\ \ldots,\ t_N)^{\top}, \\
\boldsymbol Z &= (\boldsymbol w^{\top}, \alpha)^{\top}
\end{align}

です。
観測ノイズの精度パラメータ $ \beta $ は、ここでは(簡単のため)与えられた定数とします。

$\boldsymbol w$ と $\alpha$ の事後分布 $p(\boldsymbol w, \alpha | {\bf t})$ を

p(\boldsymbol w, \alpha | {\bf t}) = q(\boldsymbol w)q(\alpha)

によって近似し、$q(\boldsymbol w),\ q(\alpha)$ を求めます。
式 $(5)$ より、$q(\boldsymbol w)$ の最適解 $q^*(\boldsymbol w)$ は

\log q^*(\boldsymbol w) = \mathbb{E}_{\alpha}[\log p({\bf t}, \boldsymbol w, \alpha)] + \text{const.} \tag{6}

で求められます。
ここで

\begin{align}
p({\bf t}, \boldsymbol w, \alpha) &= p({\bf t} | \boldsymbol w)p(\boldsymbol w|\alpha)p(\alpha), \\
p({\bf t} | \boldsymbol w) &= \mathcal{N}(\Phi\boldsymbol w,\ \beta^{-1}I), \\
p(\boldsymbol w|\alpha) &= \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I), \\
p(\alpha) &= \text{Ga}(a_0, b_0)
\end{align}

なので、これを用いて式 $(6)$ を計算すると($\boldsymbol w$ に依存しない項は $\text{const.}$ にまとめられることに注意して)

\begin{align}
\log q^*(\boldsymbol w)
&= \mathbb{E}_{\alpha}[\log p({\bf t}, \boldsymbol w, \alpha)] + \text{const.} \\
&= \mathbb{E}_{\alpha}[\log p({\bf t} | \boldsymbol w) + \log p(\boldsymbol w | \alpha)] + \text{const.} \\
&= \mathbb{E}_{\alpha}[\mathcal{N}(\Phi\boldsymbol w,\ \beta^{-1}I) + \mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I)] + \text{const.} \\ 
&= -\frac{\beta}{2} \| {\bf t} - \Phi\boldsymbol w \|^2 - \frac{\mathbb{E}[\alpha]}{2} \| \boldsymbol w \|^2 + \text{const.}
\end{align}

となります。これは $\boldsymbol w$ の2次式なので正規分布となり、線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(1)のRVMの節の式 $(10)$ 周辺と同様の計算(平方完成)を行うことで

\begin{align}
q^*(\boldsymbol w) &= \mathcal{N}(\boldsymbol m,\ S), \\
\boldsymbol m &= \beta S \Phi^{\top}{\bf t}, \\
S &= (\mathbb{E}[\alpha]I + \beta\Phi^{\top}\Phi)^{-1}
\end{align}

となります。
$\alpha$ が $\mathbb{E}[\alpha]$ に置き換わっているだけで、これまで扱った手法とほぼ同じ結果になっていることがわかります。

$q^*(\alpha)$ も求めると

\begin{align}
\log q^*(\alpha)
&= \mathbb{E}_{\boldsymbol w}[\log p({\bf t}, \boldsymbol w, \alpha)] + \text{const.} \\
&= \mathbb{E}_{\boldsymbol w}[\log p(\boldsymbol w | \alpha) + \log p(\alpha)] + \text{const.} \\
&= \mathbb{E}_{\boldsymbol w}[\mathcal{N}(\boldsymbol 0,\ \alpha^{-1}I) + \text{Ga}(a_0, b_0)] + \text{const.} \\
&= \frac{d}{2}\log \alpha - \frac{\mathbb{E}[\| \boldsymbol w \|^2]}{2}\alpha + (a_0 - 1)\log\alpha - b_0 \alpha + \text{const.} \\
&= \left( a_0 + \frac{d}{2} - 1 \right) \log\alpha - \left( b_0 + \frac{\mathbb{E}[\| \boldsymbol w \|^2]}{2} \right)\alpha + \text{const.}
\end{align}

よりこれはガンマ分布であり

q^*(\alpha) = \text{Ga}\left(a_0 + \frac{d}{2},\ b_0 + \frac{\mathbb{E}[\| \boldsymbol w \|^2]}{2}\right)

となります。ガンマ分布 $\text{Ga}(a, b)$ の期待値は $a/b$ なので、事前分布として $a_0 = b_0 = 0$ のガンマ分布を用いると近似事後分布の期待値が前回のEMアルゴリズムで求めたものと同じになることがわかります。

$ q^*(\boldsymbol w) $ は $\mathbb{E}[\alpha] $ に依存しており、$ q^{*}(\alpha) $ は $ \mathbb{E}[| \boldsymbol w|^2 ]$ に依存しているので、これらを交互に更新することで近似事後分布が求められます。

これらはそれぞれ

\begin{align}
\mathbb{E}[\alpha] &= \frac{a_0 + \frac{d}{2}}{b_0 + \frac{\mathbb{E}[\| \boldsymbol w \|^2]}{2}}, \\
\mathbb{E}[\| \boldsymbol w\|^2 ] &= \boldsymbol m^{\top}\boldsymbol m + \text{tr}S
\end{align}

で求められます。

予測分布

近似事後分布 $p(\boldsymbol w, \alpha | {\bf t}) = q^{*}(\boldsymbol w)q^{*}(\alpha)$ が求められれば予測分布も計算できます。予測分布は

\begin{align}
p(t|{\bf t})
&= \int p(t|\boldsymbol w, \alpha)p(\boldsymbol w, \alpha | {\bf t})\text{d}\boldsymbol w \text{d}\alpha \\
&= \int p(t|\boldsymbol w)p(\boldsymbol w, \alpha | {\bf t})\text{d}\boldsymbol w \text{d}\alpha \\
&= \int p(t|\boldsymbol w)q^*(\boldsymbol w)q^*(\alpha)\text{d}\boldsymbol w \text{d}\alpha \\
&= \int p(t|\boldsymbol w)q^*(\boldsymbol w)\text{d}\boldsymbol w \\
&= \int \mathcal{N}(\boldsymbol w^{\top}\boldsymbol \phi(\boldsymbol x),\ \beta^{-1})\mathcal{N}(\boldsymbol m,\ S)\text{d}\boldsymbol w \\
&= \mathcal{N}\left(\boldsymbol m^{\top}\boldsymbol \phi(\boldsymbol x),\ \beta^{-1} + \boldsymbol \phi(\boldsymbol x)^{\top}S\boldsymbol \phi(\boldsymbol x)\right)
\end{align}

となります。
ここで1行目は周辺化の計算、2行目では $\boldsymbol w$ を与えたもとで $t$ が $\alpha$ に依存しないことを利用しています。
3行目では平均場近似による近似事後分布を代入し、4行目では $\alpha$ が積分で消えることを利用しています。

式の形としては
線形回帰で理解するカーネル法・ガウス過程・RVM・EMアルゴリズム・変分ベイズ(1)のRVMの節で求めたものとほぼ同じになっています。

まとめ

今回は変分ベイズを用いて線形回帰のパラメータ $\boldsymbol w$ とハイパーパラメータ $\alpha$ の事後分布を近似的に求めました。
事後分布が陽に求められるのは簡単なモデルの場合のみなので、平均場近似を用いた周辺尤度の下界 $\mathcal{L}(q)$ の最大化という枠組みのおかげで、より一般のケースで(近似的に)事後分布を求められるようになります。

ここまでで扱ってきた手法で、線形回帰の問題をどう解いたかを簡単にまとめます。

  • カーネル法
    • パラメータ $\boldsymbol w$ のMAP推定による $t$ の予測値を、特徴写像 $\boldsymbol \phi$ を用いずにカーネル関数 $k(\boldsymbol x, \boldsymbol x')$ だけで表現した
  • ガウス過程回帰
    • パラメータ $\boldsymbol w$ を導入する代わりに $t$ がガウス過程に従うことを仮定することで予測分布 $p(t|{\bf t})$ を求めた
  • RVM
    • パラメータ $\boldsymbol w$ の事前分布として要素ごとに異なる精度パラメータ $\alpha_i$ を導入し、観測ノイズの精度パラメータ $\beta$ とともに周辺尤度最大化(経験ベイズ)で点推定を行うことで、$\boldsymbol w$ の多くの要素が予測に寄与しない疎な解を得た
  • EMアルゴリズム
    • ハイパーパラメータ $\alpha,\ \beta$ を点推定するために、周辺尤度の代わりにその下界の最大化を行うEステップ・Mステップの反復計算を適用した
  • 変分ベイズ
    • ハイパーパラメータ $\alpha$ にも事前分布を仮定($\beta$ は固定)して $\boldsymbol w$ と $\alpha$ の近似事後分布を平均場近似により求めた

3回にわたって解説してきた内容はほとんどPRMLの本文か演習問題になっている内容です。
線形回帰に適用するという簡単な例題でも、自分にとっては結構大変でした。

「線形回帰で理解する」というタイトルですが、この例だけではこれらの手法の利点と欠点を理解しきれないと思います(自分自身がよくわかっていません)。
これらのアルゴリズムを実装して予測結果を図示するなどすればよりわかりやすいかもしれません…。
python のパッケージなどで簡単に試せたりするのでしょうか。何かあれば教えていただけると助かります。