エントロピー・KL divergenceの復習


はじめに

最近、ベイズ推論を勉強しています。
ベイズ学習は、エントロピーKL divergenceなどの理解は避けては通れません。
今まで、何となく理解していましたが、変分推論あたりで「?」が大量発生しました。
これはエントロピーKL divergenceを復習しないといけないと思い、内容を整理することにしました。
この記事では、数式による導出よりも、大まかな「意味」を理解することを目指します。

エントロピー

確率分布の「複雑さ(予測のしにくさ)」を示す値です。
以下の式で計算されます。

{\rm H}[p(x)]=-p(x)\ln p(x)

「普通のサイコロ」と「六面体が歪んだサイコロ」を比べた場合、普通のサイコロの方がどの目が出るか予測が難しいです。(普通のサイコロの方がエントロピーが高くなります。)

pythonでは、scipyを用いて簡単にエントロピーが計算できます。

離散確率分布

from scipy.stats import multinomial

multinomial.entropy(n=1, p=[1 / 6.] * 6)  # 1.79175946923
multinomial.entropy(n=1, p=[0.4, 0.3, 0.1, 0.1, 0.05, 0.05])  # 1.48779838

上の普通のサイコロの方がエントロピーが高くなっているのが分かります。

連続確率分布
ガウス分布の場合も同様です。
以下のような「平均値は同じだが、分散は異なる分布」があったとします。

値の「予測のしにくさ」という観点で比べた場合、分散が大きい分布(オレンジ色の分布)の方が予測が難しいです。
なので、ガウス分布の場合、分散の大きさがエントロピーに比例することとなります。
これも、pythonで書いてみます。

from scipy.stats import norm

print norm.entropy(loc=0.0, scale=1.0)  # 1.4189385332
print norm.entropy(loc=0.0, scale=2.0)  # 2.11208571376

分散が大きい分布の方が、エントロピーが高くなることが確認できました。

交差エントロピー

KL divergenceを理解するためには、「交差エントロピー」についても知っておく必要があります。
交差エントロピーは下式で定義されます。

H(p,q)=-q(x)\ln p(x)

交差エントロピーは、2つの確率分布の間に定義される尺度として用いられます。

情報量基準の説明では、ビット計算の例を用いて説明したりしています。
Wikipedia にも以下の通り説明させています。

符号化方式が、真の確率分布 p(x) ではなく、ある所定の確率分布 q(x) に基づいている場合に、とりうる複数の事象の中からひとつの事象を特定するために必要となるビット数の平均値を表す。

これを読んでも、今いちピンときませんでした。。
そのまま、変分推論の勉強をしても、やっぱり何となくでしか分かりませんでした。
なので、少しだけ突っ込んでみました。

簡単な例で考えてみます。
ベルヌーイ分布に従う2つの確率分布 $p(x), q(x)$ を考えます。
観測データから、モデル $q(x)$ のパラメータ $\mu=0.7$ を推定しました。
まず、この場合のエントロピーは以下となります。

\begin{align}
{\rm H}[q(x)] &=-\sum_xq(x)\ln q(x)\\
&=-\{(0.7\times\ln (0.7))+(0.3\times\ln (0.3))\}\\
&=0.61086\cdots
\end{align}

真の確率分布 $p(x)$ のパラメータも $\mu=0.7$ だったとします。

交差エントロピーを計算すると、

\begin{align}
H(p,q)&=-\sum_xq(x)\ln p(x)\\
&=-\{(0.7\times\ln (0.7))+(0.3\times\ln (0.3))\}\\
&=0.61086\cdots
\end{align}

となり、「エントロピー=交差エントロピー」となります。
では、真の確率分布 $p(x)$ のパラメータが $\mu=0.5$ だったとします。
その場合の交差エントロピーを計算すると、

\begin{align}
H(p,q)&=-\sum_xq(x)\ln p(x)\\
&=-\{(0.7\times\ln (0.5))+(0.3\times\ln (0.5))\}\\
&=0.69314\cdots
\end{align}

と少しだけ値が大きくなることが分かります。

では、真の確率分布 $p(x)$ のパラメータが $\mu=0.3$ の場合はどうでしょうか。
交差エントロピーは、

\begin{align}
H(p,q)&=-\sum_xq(x)\ln p(x)\\
&=-\{(0.7\times\ln (0.3))+(0.3\times\ln (0.7))\}\\
&=0.94978\cdots
\end{align}

と更に大きくなりました。

計算してみると分かりますが、「$q(x)$=$p(x)$ で交差エントロピーは最小(=エントロピー)」となります。
$q(x)$=$p(x)$ が離れるほど、交差エントロピーは大きくなります。
交差エントロピーは「真の確率分布 $p(x)$ を仮定して生成された、観測データの分布($q(x)$)の予測のしにくさ」とでも言うのでしょうか。
真の分布と観測データの分布が一致している場合、交差エントロピーは「元々の確率分布が持つ予測のしにくさ(エントロピー)」だけです。しかし、真の分布と観測データの分布が一致していない場合、エントロピーに加えて、「分布がズレている分」だけ予測が難しくなります。

ピンとこない表現かもしれませんが、これが交差エントロピーの意味だと思います。

KL divergence

2つの確率分布の「差異」を計る尺度として用いられます。
公式としては以下の通りです。

\begin{align}
{\rm KL}[q(x)||p(x)] &= q(x)\ln\frac{p(x)}{q(x)}\\
&=q(x)\ln p(x) - q(x)\ln q(x)
\end{align}

交差エントロピーさえ分かれば、KL divergenceは「交差エントロピー - エントロピー」を計算するだけです。
この引き算は何を意味しているのでしょうか。
エントロピーと交差エントロピーの説明を復習します。

  • エントロピー・・ある確率分布の予測のしにくさ
  • 交差エントロピー・・ある確率分布を仮定して生成された、別の確率分布の予測のしにくさ

これを引くことで、「ある確率分布の予測のしにくさ」が引かれ、「ある確率分布と別の確率分布のズレ」が求まります。(言葉で説明すると余計に分かりにくい気がしますが・・)

あと、常に「交差エントロピー ≧ エントロピー」なので、KL divergenceは必ず非負値となります。

最後に、これをpythonで求めてみます。

離散確率分布

from scipy.stats import norm, entropy

p = 0.5
q = 0.7

slf_ent = -1 * (q * np.log(q) + (1 - q) * np.log(1 - q))
crs_ent = -1 * (q * np.log(p) + (1 - q) * np.log(1 - p))
print slf_ent  # 0.610864302055
print crs_ent  # 0.69314718056
print crs_ent - slf_ent  # 0.0822828785051
print entropy([q, 1 - q], [p, 1 - p])  # 0.0822828785051

連続確率分布

from scipy.stats import norm, entropy

x = np.linspace(-5, 5, 500)
p = norm.pdf(x, loc=0, scale=1.0)
q = norm.pdf(x, loc=1.0, scale=1.0)
print entropy(p, q)  # 0.499970194448

以上です。
次回は、今回学んだことをベースに変分ベイズに関してまとめたいと思います。