多クラスロジスティック回帰のIRLSバッチ処理アルゴリズム


はじめに

この記事ではPRML4.3.4に記載されている、多クラスロジスティック回帰のIRLSバッチ処理アルゴリズムを紹介します。PRML本文にはBishop and Nabney(2008)にアルゴリズムの詳細があると書いてありますが、どうも見当たらないので実際に導出・実装してみました。

準備

まずは数学的な準備から始めます。詳細はPRML4章を参照してください!
今回はクラス数K、データ数N、入力ベクトル$\boldsymbol{x}$の次元数Mで考えていきます。
つまり、

  • 入力 $(M, 1)$: $\boldsymbol{x}=\{x_1, ...,x_M\}^T$
  • パラメータ $(M, K)$: $\boldsymbol{W}=\{\boldsymbol{w}_1, ..., \boldsymbol{w}_K\}$
  • 予測値 $(N, K)$: $\boldsymbol{Y}=\{\boldsymbol{y}_1, ...,\boldsymbol{y}_N\}^T \text{where}\ \boldsymbol{y_n}=\{y_{n1}, ...,y_{nK}\}$
  • 目的値 $(N, K)$: $\boldsymbol{T}=\{\boldsymbol{t}_1, ...,\boldsymbol{t}_N\}^T \text{where}\ \boldsymbol{t_n}=\{t_{n1}, ...,t_{nK}\}$

であり、$\boldsymbol{t} $はここでは1-of-K符号化法を使っています。
(※括弧はその行列(ベクトル)のサイズを示しています)

また各クラスの予測確率はSoftmax関数で定義します。
つまり$y_n$は
$$ y_n=p\left(C_k|\ a_n\right)=\frac{\exp{a_k}}{\sum_{n=1}^{N}\exp{a_n}} $$
ただし$ a_n=\boldsymbol{w}_k^T\boldsymbol{\phi\left(\boldsymbol{x_n}\right)} $、ここで$ \boldsymbol{\phi}()$は基底関数ベクトルです。
交差エントロピー誤差関数は$ E\left(\boldsymbol{w}_1, ..., \boldsymbol{w}_K\right)=-\sum_{n=1}^{N}\sum_{k=1}^{K}{t_{nk}\ln{y_{nk}}} $であり、この誤差関数を最小にするような$ \boldsymbol{W} $を求めるのが目標です。
ニュートンラフソン法を使うので勾配とヘッシアンも計算して

  • 勾配 $(M, 1)$: $ \nabla_{w_j}E\left(\boldsymbol{w}_1, ..., \boldsymbol{w}_K\right)=\sum_{n=1}^{N}\left(y_{nj}-t_{nj}\right)\boldsymbol{\phi}_n $
  • ヘッシアン $(M, M)$: $\boldsymbol{H}_{jk}=\nabla_{w_j}\nabla_{w_k}E\left(\boldsymbol{w}_1, ..., \boldsymbol{w}_K\right)=\sum_{n=1}^{N}y_{nk}\left(\delta_{kj}-y_{nj}\right)\boldsymbol{\phi}_n\boldsymbol{\phi}_n^T$

この形状のままIRLSアルゴリズムを使うことで最適な$\boldsymbol{w}_j$をそれぞれ求めることができます。しかし、ここで知りたいのは全ての$\boldsymbol{w}_k,\ k=(1, ..., K)$を一括で求めるバッチアルゴリズムなので、次のように行列を定義します。

IRLSバッチ処理アルゴリズムの構成

ここでは、先ほど記載した各パラメータ行列の構成が変わるので新しくそれらを定義します.

  • 入力 $(M, 1)$: $\boldsymbol{x}=\{x_1, ...,x_M\}^T$
  • パラメータ $(MK, 1)$: $\hat{\boldsymbol{w}}=\{\boldsymbol{w}_1^T, ..., \boldsymbol{w}_K^T\}^T$
  • 予測値 $(NK, 1)$: $\hat{\boldsymbol{Y}}=\{\boldsymbol{y}_{1}, ...,\boldsymbol{y}_{K}\}^T,\text{where}\ \boldsymbol{y_{k}}=\{y_{1k}, ...,y_{Nk}\}$
  • 目的値 $(NK, 1)$: $\hat{\boldsymbol{T}}=\{\boldsymbol{t}_1, ...,\boldsymbol{t}_N\}^T,\text{where}\ \boldsymbol{t_k}=\{t_{nk}, ...,t_{nk}\}$

このように構成することにより勾配とヘッシアンは

  • 勾配 $(MK, 1)$ $$\nabla_{\hat{w}}E\left(\hat{\boldsymbol{w}}\right)=\hat{\boldsymbol{\Phi}}^T\left(\hat{\boldsymbol{Y}}-\hat{\boldsymbol{T}}\right) $$
  • ヘッシアン $(MK, MK)$
\begin{align}
\hat{\boldsymbol{H}}&=\nabla_{\hat{\boldsymbol{w}}}\nabla_{\hat{\boldsymbol{w}}}E\left(\boldsymbol{\hat{w}}\right)\\
&=\begin{pmatrix}
\boldsymbol{H}_{11} &... & \boldsymbol{H}_{1K} \\
... &...  &...\\
\boldsymbol{H}_{K1} &... & \boldsymbol{H}_{KK} 
\end{pmatrix}\\
&=\begin{pmatrix}
\boldsymbol{\Phi R_{11}\Phi} &... & \boldsymbol{\Phi R_{1K}\Phi} \\
... &...  &...\\
\boldsymbol{\Phi R_{K1}\Phi} &... & \boldsymbol{\Phi R_{KK}\Phi} 
\end{pmatrix}\\
&=\hat{\boldsymbol{\Phi}}^\top \hat{\boldsymbol{R}}\hat{\boldsymbol{\Phi}}
\end{align}

ただし、$\hat{\Phi}$は$(NK, MK)$行列で
$$ \hat{\boldsymbol{\Phi}}=diag(\boldsymbol{\Phi}) $$
$ \boldsymbol{\Phi} $は計画行列でサイズは$(N, M)$です。
また$\boldsymbol{\hat{R}}$は$(NK, NK)$行列で
$$\boldsymbol{\hat{R}}=\left(\begin{array}\
\boldsymbol{R}_{11} &... & \boldsymbol{R}_{1K}\\
... &... &...\\
\boldsymbol{R}_{K1} &... & \boldsymbol{R}_{KK}
\end{array}\right) $$
$$\boldsymbol{R}_{jk}=diag(y_{nk}\left(\delta_{jk}-y_{nj}\right)) $$
$\boldsymbol{R}_{jk}$は$(N, N)$行列です。

IRLSアルゴリズムの式にこれらを代入すると以下の式となります。

\begin{align}
\boldsymbol{\hat{w}}^{(new)}&=\boldsymbol{\hat{w}}^{(old)}-\boldsymbol{\hat{H^{-1}\Phi^T\left(\boldsymbol{\hat{Y}}-\hat{T}\right)}}\\
&=\boldsymbol{\hat{w}}^{(old)}-\boldsymbol{\hat{\boldsymbol{\left(\hat{\Phi}^T\hat{R}\hat{\Phi}\right)}^{-1}\Phi^T\left(\boldsymbol{\hat{Y}}-\hat{T}\right)}}\\
\end{align}

 この式をpythonを使って実装した$※_1$結果がこちらです。

しっかりと、バッチ学習で3クラス分類できています。

おわりに

ブログでもその他、ちょっとした内容を公開しています。
よければご覧になってください。

備考

$※_1$: 人工知能に関する断創録で通常のIRLSアルゴリズムを実装されていたので、参考にしました。