正規分布の尤度を算出するまでの手順


統計学初歩の「尤度」なるものに躓いてしまいました。コマンドを動かしながら正規分布の(同時)確率がどのように変化するのか観察してみます。

正規分布

毎度お馴染み、正規分布は以下の通り表されます。わかりやすさのために、左辺を$P(x)$と表現しています。

$$P(x)={1 \over \sqrt{2\pi\sigma^{2}}} \exp \left(-{1 \over 2}{(x-\mu)^2 \over \sigma^2} \right)$$

ライブラリ
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy.random as rd
import matplotlib.gridspec as gridspec
%matplotlib inline
plt.rcParams['font.size']=15

def plt_legend_out():
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)

コード
m = 10
s = 3

min_x = m-4*s
max_x = m+4*s

x = np.linspace(min_x, max_x, 201)
y = (1/np.sqrt(2*np.pi*s**2))*np.exp(-0.5*(x-m)**2/s**2)

plt.xlim(min_x, max_x)
plt.ylim(0,max(y)*1.1)
plt.plot(x,y)
plt.show()

正規分布からランダムにデータを10個取り出してみます。

コード
plt.figure(figsize=(8,1))
rd.seed(7)
data = rd.normal(10, 3, 10, )
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.tick_params(left=False,labelleft=False)
plt.axhline(y=0,color='gray',lw=0.5)
plt.show()

同時確率

上記のデータが同時に発生する確率は、以下の通りです。

\begin{eqnarray}
\prod_{i=1}^NP(x) &=& P(x_1, x_2,\cdots,x_{10})\\
&=& P(x_1)P(x_2)\cdots P(x_{10})
\end{eqnarray}

ただし上記の表し方ですと、計算する際の困りごとがあります。確率同士の積が、かなり小さい値になってしまいます。小数点以下の計算は、0.01×0.01=0.001のようにゼロが増えていきます。これが10回、100回続くと、ゼロが多くなってややこしいです。そこで、$log$をとってみます。

\begin{eqnarray}
\prod_{i=1}^{10}\log{P(x_i)} &=& \log{P(x_1,x_2,\cdots,x_{10})}\\
&=& \log{(P(x_1)×P(x_2)\cdots×P(x_{10}))}\\
&=& \log{P(x_1)}+\log{P(x_2)}+\cdots+\log{P(x_{10})}\\
&=& \color{red}{\sum_{i=1}^{10}\log{P(x_i)}}
\end{eqnarray}

$log$をとることで、足し算の問題に置き換えることが出来ました。本ケースの$P(x)$は正規分布のことです。よって、上記データの同時確率は以下の通りです。

\begin{eqnarray}
\color{red}{\sum_{i=1}^{10}\log{P(x_i)}} &=& \sum_{i=1}^N{1 \over \sqrt{2\pi\sigma^{2}}} \exp \left(-{1 \over 2}{(x_i-\mu)^2 \over \sigma^2} \right)
\end{eqnarray}

実装

ここで、$\sigma=3$を固定して、$\mu=5, 10, 15$の正規分布を考えてみます。先ほど取り出したデータも併せてプロットしてみます。視覚的には、$\mu=10$と$\sigma=3$の条件が当てはまりが良いように見えます。

関数
def norm_dens(x,m,s):
    return (1/np.sqrt(2*np.pi*s**2))*np.exp(-0.5*(x-m)**2/s**2)

def log_likelihood(x,m,s):
    L = np.prod([norm_dens(x_i,m,s) for x_i in x])
    l = np.log(L)
    return l

コード
logp_ymin = -10 ;logp_ymax = 0
d_ymin = -0.01 ; d_ymax = 0.2

plt.figure(figsize=(10,2))
plt.subplots_adjust(hspace=0.1, wspace=0.1)

x = np.linspace(0, 20, 100)

###########

plt.subplot(131)
m=5 ; s=3
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.ylabel('density')
plt.axhline(y=0,color='gray',lw=0.5)
plt.ylim(d_ymin,d_ymax)
plt.xlim(0,20)
plt.tick_params(bottom=False,labelbottom=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

###########

plt.subplot(132)
m=10 ; s=3
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.axhline(y=0,color='gray',lw=0.5)
plt.xlim(0,20)
plt.ylim(d_ymin,d_ymax)
plt.tick_params(bottom=False,labelbottom=False,left=False,labelleft=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

###########

plt.subplot(133)
m=15 ; s=3
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.axhline(y=0,color='gray',lw=0.5)
plt.xlim(0,20)
plt.ylim(d_ymin,d_ymax)
plt.tick_params(bottom=False,labelbottom=False,left=False,labelleft=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

plt.show()

では、$\log{P(x)}$も併せてプロットしてみましょう。$\mu=10$と$\sigma=3$の条件で$\sum{\log{P(x)}}$が最も大きくなりました。

コード
logp_ymin = -10 ;logp_ymax = 0
d_ymin = -0.01 ; d_ymax = 0.2

plt.figure(figsize=(10,2))
plt.subplots_adjust(hspace=0.1, wspace=0.1)

x = np.linspace(0, 20, 100)

###########

plt.subplot(131)
m=5 ; s=3
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.ylabel('density')
plt.axhline(y=0,color='gray',lw=0.5)
plt.ylim(d_ymin,d_ymax)
plt.xlim(0,20)
plt.tick_params(bottom=False,labelbottom=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

###########

plt.subplot(132)
m=10 ; s=3
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.axhline(y=0,color='gray',lw=0.5)
plt.xlim(0,20)
plt.ylim(d_ymin,d_ymax)
plt.tick_params(bottom=False,labelbottom=False,left=False,labelleft=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

###########

plt.subplot(133)
m=15 ; s=3
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.axhline(y=0,color='gray',lw=0.5)
plt.xlim(0,20)
plt.ylim(d_ymin,d_ymax)
plt.tick_params(bottom=False,labelbottom=False,left=False,labelleft=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

plt.show()

次は$\mu=10$を固定して$\sigma=2,3,5$の正規分布を見てみましょう。先ほどの大きな違いは見られないものの、$\mu=10$と$\sigma=3$の条件で最も$\sum{\log{P(x)}}$が大きくなりました。$\sigma=5$の分布はデータに照らし合わせて広すぎる印象です。つまり$\sum{\log{P(x)}}$が大きい(同時確率が大きい)条件で、データと分布の当てはまりが良いことがわかります。

コード
logp_ymin = -6 ;logp_ymax = 0
d_ymin = -0.01 ; d_ymax = 0.25

plt.figure(figsize=(10,4))
plt.subplots_adjust(hspace=0.1, wspace=0.1)

x = np.linspace(0, 20, 100)

###########

plt.subplot(231)
m=10 ; s=2
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.ylabel('density')
plt.axhline(y=0,color='gray',lw=0.5)
plt.ylim(d_ymin,d_ymax)
plt.xlim(0,20)
plt.tick_params(bottom=False,labelbottom=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

plt.subplot(234)
xl = [np.log(norm_dens(x,m,s)) for x in data]
plt.scatter(data,xl,color='red')
plt.xlim(0,20)
plt.ylabel('log p(x)')
plt.ylim(logp_ymin,logp_ymax)
plt.xlabel('x')
for i in range(len(data)):plt.plot([data[i],data[i]],[0,xl[i]],color='gray',lw=0.5,ls='dashed')
plt.text(3,-10,'$\sum{\log{p(x)}}$='+str(np.round(np.sum(xl),1)))

###########

plt.subplot(232)
m=10 ; s=3
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.axhline(y=0,color='gray',lw=0.5)
plt.xlim(0,20)
plt.ylim(d_ymin,d_ymax)
plt.tick_params(bottom=False,labelbottom=False,left=False,labelleft=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

plt.subplot(235)
xl = [np.log(norm_dens(x,m,s)) for x in data]
plt.scatter(data,xl,color='red')
plt.xlim(0,20)
plt.ylim(logp_ymin,logp_ymax)
plt.tick_params(left=False,labelleft=False)
plt.xlabel('x')
for i in range(len(data)):plt.plot([data[i],data[i]],[0,xl[i]],color='gray',lw=0.5,ls='dashed')
plt.text(3,-10,'$\sum{\log{p(x)}}$='+str(np.round(np.sum(xl),1)))

###########

plt.subplot(233)
m=10 ; s=5
y = norm_dens(x,m,s)
plt.plot(x,y,label='$\mu=$'+str(m)+'$, \sigma=$'+str(s))
plt.scatter(data, np.zeros_like(data), c="r", s=50)
plt.axhline(y=0,color='gray',lw=0.5)
plt.xlim(0,20)
plt.ylim(d_ymin,d_ymax)
plt.tick_params(bottom=False,labelbottom=False,left=False,labelleft=False)
plt.title('$\mu=$'+str(m)+', $\sigma=$'+str(s))

plt.subplot(236)
xl = [np.log(norm_dens(x,m,s)) for x in data]
plt.scatter(data,xl,color='red')
plt.xlim(0,20)
plt.ylim(logp_ymin,logp_ymax)
for i in range(len(data)):plt.plot([data[i],data[i]],[0,xl[i]],color='gray',lw=0.5,ls='dashed')
plt.tick_params(left=False,labelleft=False)
plt.xlabel('x')
plt.text(3,-10,'$\sum{\log{p(x)}}$='+str(np.round(np.sum(xl),1)))
plt.show()

最後に$\mu$と$\sigma$を細かく探索してみます。探索といっても、網羅的に探索しているだけです。結果、真値に近しい値を推定出来ました。

コード
mus = np.linspace(8, 12, 50)
ss  = np.linspace(2, 4, 50)
lmu = [] ; ls = [] ; lll = []

for mu in mus:
    for s in ss:
        lmu.append(mu)
        ls.append(s)
        lll.append(log_likelihood(data,mu,s))

plt.scatter(lmu,ls,c=lll,alpha=0.8)
plt.xlabel('$\mu$')
plt.ylabel('$\sigma$')
plt.colorbar()
plt.scatter(10,3,color='r')
plt.text(10.1,3.1,'true',color='r')

pmu,ps,pll = pd.DataFrame([lmu,ls,lll]).T.sort_values(2,ascending=False).reset_index(drop=True).loc[0,:].to_numpy()
plt.scatter(pmu,ps,color='b')
plt.text(pmu+0.1,ps+0.1,'predicted',color='b')

plt.title('同時確率')
plt.show()

同時確率から尤度へ

今回は「正規分布からデータが発生した」ことがわかっていました。しかし現実問題では、データの背景にある確率分布を仮定することになります。その際は「同時確率」から「尤度」と表現を置き換えて分布のパラメータを推定することになります。やっていることは同じです。

参考URL