イージーでレイジーな人のための 「サルでも分かる EM アルゴリズム」


機械学習の周辺を見ていると「EM アルゴリズム」という単語をよく目にします。
この分野では重要なアルゴリズムなのですが、正直よく分からない、「最尤推定(さいゆうすいてい)の特殊形」くらいにしか理解できていない、なんなら最尤推定の理解も怪しい。

だから、もう少し理解したいと思っているが、数式はできるだけ見たくない。見ても分からん。
なんというか、こう、正確でなくていいから、数式やら細かい議論やらを全部ぶっ飛ばして、イメージ的に理解したつもりになりたい!雰囲気だけ味わいたい!

本稿は、そんなイージーでレイジーな皆さまのための、EMアルゴリズムの説明です。
内容が不正確であることを承知の上で読んでください。

なお、これから本気で機械学習を勉強していきたいと考えているガチ勢志望は、このあたりを見るなり、普通に PRML を読むなりする方がいいと思います。

最尤推定とは

コイントスを10回繰り返したところ、表が7回、裏が3回出たとします。このとき、このコインの表が出る確率は、どれくらいになるでしょうか?

🟢 :表
🟤 :裏

🟢🟢🟤🟢🟢🟤🟤🟢🟢🟢

・・・はい、もちろん、7/10 です。
だって、10回中7回表が出たという「観測データ」を得ているわけですから。

え?
コインの表が出る確率は 1/2 に決まってるだろって?

それはあなたの勝手な思い込みです。
現に目の前にデータがあって、これを最も尤もらしく説明できる推定(最尤推定)は、7/10 に決まっているのです。

神様しか知り得ない「本当の確率」が 1/2 だったとしても、動かぬ事実として「10回中7回表だった」ことしか観測できない我々下界の人間は、素直に 7/10 と受け止めるしかないのです。

というわけで、Python のプログラムはこちら。
超簡単です。numpy の出番など1ミリもありません。

estimation.py
import random
NUM_TOSS = 10
r = [random.randint(0,1) for i in range(NUM_TOSS)]
print(r.count(1)/NUM_TOSS)

EM アルゴリズムとは

現実の世界では、データを正確に観測できるとは限りません。
データが欠損していたり、潜在変数としてそもそも観測できなかったりすることがほとんどです。

観測データが不完全でも最尤推定したい・・・!!
そんなワガママな欲求を満たすアルゴリズム、そう、それが「EM アルゴリズム」です。

コイントスを10回繰り返したところ、表が5回、裏が2回出て、残りの3回はうまく観測できなかったとします。このとき、このコインの表が出る確率は、どれくらいになるでしょうか?

🟢 :表
🟤 :裏
❌ :不明

🟢🟤❌🟢🟢❌🟤🟢❌🟢

まず、表が出る確率を適当に 1/2 と仮定しましょう。
「さっき勝手な思い込みって言ったじゃないか!」という苦情はスルーです。この 1/2 は単なる初期値です。深い意味などありません。
実際、0より大きく、1より小さい数であれば、どんな値であっても構わないのです。

さて、きちんと表が観測できたところは「1」と数えられるのですが、できなかったところは数えられませんね。そこで、先ほど仮定した「表が出る確率と同じだけの回数分だけ出た」として数えましょう。

要するに、「10回中5回が表」に不明3回分の 1/2+1/2+1/2=1.5回を足して、「表は5+1.5=6.5回出た」と数えましょうということです。

次に、先ほど仮定した表が出る確率を 1/2 → 6.5/10 に更新し、同じ計算をします。つまり、「10回中5回が表」に不明3回分の 6.5/10+6.5/10+6.5/10=1.95回を足して、「表は5+1.95=6.95回出た」とします。

以下、同じことを繰り返します。

6.5/10 → 6.95/10 に更新し、同じ計算をすれば 2.085 回なので、「表は5+2.085=7.085回出た」となり、さらに 6.95/10 → 7.085/10 に更新し、同じ計算をすれば 2.126 回なので、「表は5+2.126=7.126回出た」となり・・・

この計算を繰り返していると、そのうちほとんど値が変化しなくなります。更新幅が小さくなっていくんですね。これを業界用語で「サチる」と言います(「収束する」を意味する「saturation」の俗語)。

というわけで、Python で計算してみましょう。

em_estimation.py
import random
import math

NUM_TOSS = 10

# 1が表、0が裏、2が不明
r = [random.randint(0,2) for i in range(NUM_TOSS)]

THRESHOLD = 0.001
MAX_ITERATION = 20

param = 0.5
log_likelihood = []

for i in range(MAX_ITERATION):
    num_front = r.count(1) + r.count(2) * param
    param_old = param
    param = num_front / NUM_TOSS
    log_likelihood.append(num_front * math.log(param) + (NUM_TOSS - num_front) * math.log(1 - param))
    if abs(param - param_old) <= THRESHOLD:
        break

print(param)

ちょっと真面目な話に入ってみる

さて、先ほどのコイントスの例で「7/10 が最も尤らしい推定値である」と、さも当然のように言い切ったのですが、なぜ言い切れるのでしょうか?「10回中7回表が出たという観測データを得ている」ことを根拠にすると、なんとなく分かったつもりになりますが、なぜそれが根拠になるのでしょうか?

それは、「コインの表(または裏)は二項分布にしたがって観測される」と仮定しているからです。

つまり、尤度関数を二項分布とし、「N 回中 k 回表が出た」という観測が得られた場合、尤度を最大にする(その観測を最も尤もらしく説明できる)二項分布のパラメータは、必ず k/N になるのです → 数学的な証明はこのあたりを読みましょう

EM アルゴリズムでは、k について正確な値が分かりません。だって、何らかの原因で観測できなかったわけですから。でも、繰り返し計算でパラメータを更新するたびに、尤度は大きくなっていく(徐々に最大値に近づいていく=常に尤もらしさが増していく)ことが数学的に保証されています。

先のサンプルプログラム(em_estimation.py)の変数「log_likelihood」(対数尤度)を表示させると、それが分かります。

「表が5回、裏が2回、不明が3回」の場合、対数尤度は次のとおりです。

[-6.4744663903463255, -6.15041454449346, -6.034893004000668, -5.998526353071156, -5.9874605019982035, -5.984126665068666, -5.983125245152495, -5.982824704941832]

ほら、大きくなってる。

もちろん、観測が完全に得られた場合のように、真の最大値がバシッと求まるわけではありません。でも、がんばって繰り返していればだいたい正しいところまで行けることが、「数学的に保証されている」というのは、とても重要です。

だからぁ!…安心して君達はぁ!…パラメータを更新したらいいんだよぉぅ。
おーん → 接点t