【教師あり学習: 分類】交差エントロピー


はじめに

機械学習に関する書籍の中で損失関数の交差エントロピーがよくわからなかったので調べました!

交差エントロピーは二種類ある教師あり学習(①回帰②分類)の中で後者である分類の損失関数として使用されます。

交差エントロピーは情報理論で登場する「情報エントロピー」と統計学で登場する「尤度関数・最尤推定」の二通りから考えることができます。この記事では後者の尤度関数・最尤推定から交差エントロピーを考えたいと思います!

注意

分類問題のシステム自体(どのように出力データを出すか、そもそも損失関数は何か)については割愛しています。

まずは導入として最尤推定についてみていきます!

尤度と最尤推定

「ある条件」の下で観測値が得られたときに、その観測値を使って「ある条件」がどんなものであるかを推定することを最尤推定といいます。その際に出てくる「尤度関数」、「尤度」についてみていきたいと思います。(この記事では尤度関数と尤度は別の用語として扱います)

ここで出てくる尤度関数、尤度が次の章で扱う交差エントロピーで大事になってきます!

では例を見ていきましょう。

「次出会う人が左利きである確率」を最尤推定してみます。街に出て出会った人にアンケートをしました。

1 2 3 4 5
結果

アンケートをとった5人中2人が左利きでした。この結果から次会う人が左利きである確率を次のように設定します。

\begin{align}
(左利き) &= p \\
(左利きでない) &= 1-p \\
\end{align}

次に、1人目が左利き、2人目が左利き、3人目が右利き、4人目が右利き、5人目が右利きである確率Pを考えます。これが尤度関数です。

\begin{align}
P &= p\times p\times (1-p)\times (1-p)\times (1-p) \\
&= p^2\times (1-p)^3 \\
\end{align}

みてわかる通り、尤度関数Pは左利きである確率pの関数となりました。では、この関数にp=0.1(次出会う人が10%の確率で左利き)を代入してみます。

\begin{align}
P &= 0.007 \\
\end{align}

となりました。この計算によって求めた値のことを尤度(ゆうど)といいます。この値自体に意味はなく、他のpを代入したときの尤度と比べて、どちらのpが尤もらしいかを判断します。ではp=0.5を代入したらどうなるでしょうか。

\begin{align}
P &= 0.031 \\
\end{align}

先ほどよりは大きい尤度になりました。よってp=0.5の方がp=0.1より尤もらしい数字であるということがわかります。では一番尤もらしいpを対数やら微分やらを用いて求めてみます。

\begin{align}
log\, P &= 2\,log\, p+ 3\,log\, (1-p) \\
\frac{d}{dp}log\, P &= \frac{2}{p}- \frac{3}{1-p}=0  \\
\Leftrightarrow p = \frac{2}{5} \\
\end{align}

となりました。この値が尤もらしいpということですね。これで最尤推定することができました。

ちなみに、p=2/5のときの尤度は

\begin{align}
P &= \bigl( \frac{2}{5} \bigr)^2\times \bigl(1-\frac{2}{5} \bigr)^3 \\
&= 0.035 \\
\end{align}

と先ほどの二つより高い値ですね。よく自然対数をとることによって尤度の大小関係をみやすくする工夫がされることがあります。

・p=0.1のとき

\begin{align}
log(0.007) &= -4.96 \\
\end{align}

・p=0.5のとき

\begin{align}
log(0.031) &= -3.47 \\
\end{align}

・p=0.4(2/5)のとき

\begin{align}
log(0.035) &= -3.35 \\
\end{align}

(マイナスがあるせいでそんなにみやすくなさそう、、笑)

次は交差エントロピーです。

交差エントロピー

ここでは左利きと右利きを分けたいとします。

正解値として、左利きと右利きの真の値を

\begin{align}
左 &= 1 \\
右 &= 0 \\
\end{align}

とします。

正解値を予測する予測機に、出会った5人が左利きかそうでないかを推定してもらいました。

1 2 3 4 5
正解値 1 1 0 0 0
予測値 0.5 0.5 0.5 0.5 0.5

予測値は全て0.5となりました。予測値は1に近い方が左利きと推定されやすく、0に近い方が右利きと推定されやすいので、かなり精度が悪いですね。

このときの尤度関数を考えたいと思います。1人目が左利き、2人目が左利き、3人目が右利き、4人目が右利き、5人目が右利きである確率Pです。

\begin{align}
P &= p_{1}\times p_{2}\times (1-p_{3})\times (1-p_{4})\times (1-p_{5}) \\
&= 0.5\times 0.5\times (1-0.5)\times (1-0.5)\times (1-0.5) \\
&= 0.031 \\
\end{align}

p1、p2、p3、p4、p5は全てその人が左利きであると推定する確率を表しています。「尤度と最尤推定」の例でもあったようにこの0.031は尤度です。

さらに、この値にある細工を加えます。自然対数をとり、マイナスをかけちゃうんです。

\begin{align}
-log(0.031) &= 3.47 \\
\end{align}

この数字がこの予測値の交差エントロピーです。

尤度を求めることによって交差エントロピーもほぼ求まったようなもんじゃありませんか

この計算過程からもわかりますが、尤度にちょっとした細工を加えているだけなので、交差エントロピーも尤度同様に値自体に意味はありません。

この予測器は精度が悪いので、別のものに変えた結果次のようになりました。

1 2 3 4 5
正解値 1 1 0 0 0
予測値 0.8 0.9 0.1 0.1 0.1

それぞれの予測値が正解値に近づいたので、先ほどの予測機より精度はいいですね。

このときの尤度関数を考えます。

\begin{align}
P &= p_{1}\times p_{2}\times (1-p_{3})\times (1-p_{4})\times (1-p_{5}) \\
&= 0.8\times 0.9\times (1-0.1)\times (1-0.1)\times (1-0.1) \\
&= 0.525 \\
\end{align}

尤度から交差エントロピーを求めます。

\begin{align}
-log(0.525) &= 0.645 \\
\end{align}

あれ、尤度のときは予測精度がいい方が値が大きかったのに、交差エントロピーのときは予測精度がいい方が値が小さいではありませんか!

これは尤度にマイナスをかけているためです。

交差エントロピーは損失関数である以上、予測精度がよければ、その分損失が少ないという意味で損失関数の値が小さくなっていた方が直感的にわかりやすいということから、マイナスをつけているだけです。また、自然対数もその後の計算をしやすくするためにしているだけです。

簡単な説明は以上です。次はより詳細に説明します。

より実践っぽく

one-hotベクトルを使って

one-hotベクトルなるものをご存知でしょうか。

\begin{align}
左 &= [1,0] \\
右 &= [0,1] \\
\end{align}

こうやって0と1で左と右を区別しています。

これは二つを分けていますが、三つやそれ以上の場合もone-hotベクトルにすることができます。

\begin{align}
赤 &= [1,0,0] \\
青 &= [0,1,0] \\
黄 &= [0,0,1] \\
\end{align}

こんな感じです。

このone-hotベクトルは正解値として振舞います。次は実際に、交差エントロピーが損失関数として利用される二値分類や多値分類をみていきます。

二値分類

他の記事などでよく見る交差エントロピーの公式?みたいなやつを考えます。

\begin{align}
E=- \sum_{i=1} \; t\,log\,p \\
\end{align}

tに正解値(one-hotベクトル)、pに予測値を代入します。その例をみましょう。

1 2 3 4 5
正解値 1 1 0 0 0
予測値 0.8 0.9 0.1 0.1 0.1

先ほど使った表です(より予測精度がいい方)。この表の正解値と予測値をone-hotベクトルでみると、

\begin{align}
(正解値)&=(t_{1},t_{2},t_{3},t_{4},t_{5}) = \\
\end{align}
\begin{pmatrix}
1 & 1 & 0 & 0 & 0 \\
0 & 0 & 1 & 1 & 1 \\
\end{pmatrix}
\begin{align}
(予測値)&=(p_{1},p_{2},p_{3},p_{4},p_{5}) = \\
\end{align}
\begin{pmatrix}
0.8 & 0.9 & 0.1 & 0.1 & 0.1 \\
0.2 & 0.1 & 0.9 & 0.9 & 0.9 \\
\end{pmatrix}

のようになります。正解値、予測値はそれぞれ行列の一行目が表の値と対応しています。また、新たに登場した行列の二行目は右利きに対する正解値、予測値が対応しています。

交差エントロピーを求めてみます。

\begin{align}
E&=- \sum_{i=1} \; t\,log\,p \\
&=-\bigl( t_{1}\,log\,p_{1}+ t_{2}\,log\,p_{2}+ t_{3}\,log\,p_{3}+ t_{4}\,log\,p_{4}+ t_{5}\,log\,p_{5} \bigr) \\
\end{align}
=
-
\Biggl(\;
\begin{pmatrix}
1  \\
0  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.8  \\
log\,0.2  \\
\end{pmatrix}
+
\begin{pmatrix}
1  \\
0  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.9  \\
log\,0.1  \\
\end{pmatrix}
+
\begin{pmatrix}
0  \\
1  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.1  \\
log\,0.9  \\
\end{pmatrix}
+
\begin{pmatrix}
0  \\
1  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.1  \\
log\,0.9  \\
\end{pmatrix}
+
\begin{pmatrix}
0  \\
1  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.1  \\
log\,0.9  \\
\end{pmatrix}\;
\Biggr)
\begin{align}
&=-\bigl( log\,0.8+ log\,0.9+ log\,0.9+ log\,0.9+ log\,0.9 \bigr) \\
&=0.645 \\
\end{align}

となりました。「交差エントロピー」のところで出した値と同じになりましたね。

二値分類でone-hotベクトルに慣れてところで、次は多値分類の場合もみたいと思います。

多値分類

ここでは、赤・青・黄の三種類を分類するとします。正解値、予測値は以下のようになったとします。

\begin{align}
(正解値)&=\\
\end{align}
\begin{pmatrix}
1 & 1 & 0 & 0 & 0 \\
0 & 0 & 1 & 0 & 0 \\
0 & 0 & 0 & 1 & 1 \\
\end{pmatrix}
\begin{align}
(予測値)&=\\
\end{align}
\begin{pmatrix}
0.7 & 0.8 & 0.1 & 0.1 & 0.1 \\
0.2 & 0.1 & 0.8 & 0.2 & 0.2 \\
0.1 & 0.1 & 0.1 & 0.7 & 0.7 \\
\end{pmatrix}

正解値の行列は一行目が赤、二行目が青、三行目が黄を表しています。

二値分類のときと同様に交差エントロピーを求めます。

\begin{align}
E&=- \sum_{i=1} \; t\,log\,p \\
&=-\bigl( t_{1}\,log\,p_{1}+ t_{2}\,log\,p_{2}+ t_{3}\,log\,p_{3}+ t_{4}\,log\,p_{4}+ t_{5}\,log\,p_{5} \bigr) \\
\end{align}
=
-
\Biggl(\;
\begin{pmatrix}
1  \\
0  \\
0  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.7  \\
log\,0.2  \\
log\,0.1  \\
\end{pmatrix}
+
\begin{pmatrix}
1  \\
0  \\
0  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.8  \\
log\,0.1  \\
log\,0.1  \\
\end{pmatrix}
+
\begin{pmatrix}
0  \\
1  \\
0  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.1  \\
log\,0.8  \\
log\,0.1  \\
\end{pmatrix}
+
\begin{pmatrix}
0  \\
0  \\
1  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.1  \\
log\,0.2  \\
log\,0.7  \\
\end{pmatrix}
+
\begin{pmatrix}
0  \\
0  \\
1  \\
\end{pmatrix}^T  
\begin{pmatrix}
log\,0.1  \\
log\,0.2  \\
log\,0.7  \\
\end{pmatrix}\;
\Biggr)
\begin{align}
&=-\bigl( log\,0.7+ log\,0.8+ log\,0.8+ log\,0.7+ log\,0.7 \bigr) \\
&=1.51 \\
\end{align}

ありがとうございました!

参考リンク

この記事を書くに当たって、
損失関数)クロスエントロピー誤差
[ 機械学習 ] 損失関数の実装 (二乗和誤差, 交差エントロピー誤差)
【統計学】尤度って何?をグラフィカルに説明してみる。
が大変参考になりました。ありがとうございました!