Pyroで始めるベイズ推定


ベイズ推定とは

観測値: $X$
パラメータ: $\theta$
変分パラメータ: $\phi$

ある確率モデル$P(X|\theta)$に対して古典的な最尤推定法では尤度を最大にするようなパラメータを点推定します。

\theta^* = \mathrm{argmax}_{\theta} P(X|\theta)

一方でベイズ推定ではパラメータも確率変数とみなし、仮定した事前確率と観測データの尤度から事後確率を求めます。

P(\theta|X) = \frac{P(\theta)P(X|\theta)}{P(X)}

このようにパラメータを点推定ではなく分布で推定することで過学習を防ぐことができるだけでなく、パラメータの分布を用いて不確実性を考慮した意思決定を行うことができます。言うなれば、点推定では最も起こりやすい場合のみ考え、ベイズ推定ではあらゆる場合を考えると言えるでしょう。

ベイズ推定で最も困難な問題は周辺尤度の積分

P(X) = \int d\theta P(\theta)P(X|\theta) = \int d\theta P(X,\theta)

が多くの場合に困難であることに尽きます。簡単なモデルについては共役事前分布を仮定することで解析的に解ける場合もありますが、そうでない場合はLaplace近似、MCMCを用いたサンプリング、変分推定などが使用されます。

変分推定

変分推定(Variational Inference)では事後分布をパラメトリックな特定の分布$Q(\theta|\phi)$で近似し、その変分パラメータ$\phi$を(点)推定するというアプローチを取ります。そして変分パラメータは以下の変分自由エネルギー(変分下限、Evidence Lower Bound, ELBO) $\mathcal{F}(\phi)$を最大にするように設定します。

\begin{align}
\mathrm{ln}P(X) & = \mathrm{ln}\int d\theta P(X,\theta) \\
 & = \mathrm{ln}\int d\theta Q(\theta|\phi) \frac{P(X,\theta)}{Q(\theta|\phi)} \\
 & \geq \int d\theta Q(\theta|\phi)\mathrm{ln} \frac{P(X,\theta)}{Q(\theta|\phi)} \\
 & = \mathrm{E}_{Q(\theta|\phi)} \left( \mathrm{ln}\frac{P(X,\theta)}{Q(\theta|\phi)} \right) \\
 & \equiv \mathcal{F}(\phi)
 \end{align}

3行目の不等式はJensenの不等式として知られています。

なお、ELBOは対数周辺尤度の下限でですが、両者の差は真の事後分布と近似事後分布のKLダイバージェンスという量であり、2つの確率分布の距離を表します。

\begin{align}
\mathrm{ln}P(X) - \mathcal{F}(\phi) & = \mathrm{ln}P(X) - \mathrm{E}_{Q(\theta|\phi)} \left( \mathrm{ln}\frac{P(X,\theta)}{Q(\theta|\phi)} \right) \\
 & = \mathrm{E}_{Q(\theta|\phi)} \left( \mathrm{ln}\frac{Q(\theta|\phi)P(X)}{P(X,\theta)} \right) \\
& = \mathrm{E}_{Q(\theta|\phi)} \left( \mathrm{ln}\frac{Q(\theta|\phi)}{P(\theta | X)} \right) \\
& \equiv \mathrm{KL}[Q|P]
\end{align}

つまり、ELBOを最大するということはKLを最小化するということであり、$Q$と$P$を近づけるということになります。

とまぁ理論は至ってシンプルで美しいのだが、実際の計算はとにかくエグい。例えばWikipediaに載っている簡単な1次元の正規分布の例でも次のような大変な導出を必要とするのでとにかく手計算はツライ。
https://en.wikipedia.org/wiki/Variational_Bayesian_methods#Derivation_of_q(%CE%BC)

Pyroで正規分布のベイズ推定

ここでやっとPyroの登場です。PyroはPyTorchをベースにした確率プログラミングのフレームワークでstanやTensorflow Probabilityと同様に使用することができます。早い話、自動微分や近似を用いてELBOの微分 $\nabla _{\phi} \mathcal{F}(\phi)$ を自動で評価し、SGDを用いて最適化することができます。すごい!!

今回は1次元の正規分布の期待値と標準偏差パラメータの事後分布を求めてみます。

データ準備

# 必要なものをいろいろimport
import matplotlib.pyplot as plt
import torch
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TracePredictive, EmpiricalMarginal
import pyro.distributions as dist
from torch.distributions import constraints

mu = 10
sigma = 0.5
n = 10
data = torch.randn(n) * sigma + mu

モデル定義

実装に入る前にまずは数式でモデルを書き下してみます。

  1. パラメータ: $\theta = (\mu, \sigma)$
  2. 尤度: $P(X| \theta) = \mathrm{Normal}(X | \mu, \sigma)$
  3. 事前分布
    • $P(\mu) = \mathrm{Normal}(\mu |0, 20)$
    • $P(\sigma) = \mathrm{Exponential}(\sigma|10)$
  4. 変分パラメータ $\phi = (m_1, s_1, m_2, s_2)$
  5. 近似分布
    • $Q(\mu|m_1, s_1) = \mathrm{Normal}(\mu |m_1, s_1)$
    • $Q(\sigma|m_2, s_2) = \mathrm{LogNormal}(\sigma |m_2, s_2)$

事前分布はできるだけ裾野の長い分布を使用します。期待値パラメータは任意に字数なので正規分布、標準偏差パラメータは正の実数なので指数分布を使用します。近似分布も同様に設定しますが、指数分布だとパラメータが1つしか無く自由度が低いので対数正規分布を使用します。共役性を仮定しないのでこのように自由にモデリングできるのがいいですね。

実装

Pyroではmodelguideという2つの関数を作成します。前者は事前分布と尤度、後者は近似分布を表します。

前節の式を実装していきます。

def model(data):
    # muの事前分布
    m = torch.tensor(0.0)
    s = torch.tensor(20.0)
    mu = pyro.sample("mu", dist.Normal(m, s))

    # sigmaの事前分布
    lam = torch.tensor(10.0)
    sigma = pyro.sample("sigma", dist.Exponential(lam))

    # 尤度
    with pyro.plate("data_plate", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma))


def guide(*args, **kwds):
    # muの変分パラメータと近似分布
    m1 = pyro.param("m1", torch.tensor(0.0))
    s1 = pyro.param("s1", torch.tensor(1.0),
                      constraint=constraints.positive)
    mu = pyro.sample("mu", dist.Normal(m1, s1))

    # muの変分パラメータと近似分布
    m2 = pyro.param("m2", torch.tensor(0.0))
    s2 = pyro.param("s2", torch.tensor(1.0),
                      constraint=constraints.positive)
    sigma = pyro.sample("sigma", dist.LogNormal(m2, s2))

    return mu.item(), sigma.item()

まずmodelですが、事前分布の定数をTensorで定義し、pyro.sampleに変数名と事前分布を渡して計算グラフを作成します。尤度も同様にpyro.sampleを使用しますが、dataが配列であることを明示的に扱うためにpyro.plateというコンテキストを使用します。

guideでは最適化したい変分パラメータをpyro.paramに名前と初期値を渡して作り、それをpyro.sampleに渡します。推定したい事後分布の変数名(この場合はmusigma)はmodelguideで必ず対応するようにpyro.sampleに渡します。

モデルの定義ができたのでパラメータの推定を行います。基本的にはPyTorchと同じで目的関数、最適化アルゴリズムを指定してループを回します。目的関数はTrace_ELBO、最適化アルゴリズムはAdamを使用します。AdamはPyroのwrapper版を使用します。また、実際の観測データで条件化したobs_modelpyro.conditionで作成し、最適化時にはmodelの代わりにこちらを使用します。モデル、最適化アルゴリズム、目的関数をまとめてSVIクラスを作成し、このオブジェクトのstepを繰り返し呼ぶことで最適化を実行できます。Pyroではグローバル変数に現在の計算グラフの各変数を保存しているため、実際に最適を実行する前に一度だけclear_param_store関数で初期化します。

obs_model = pyro.condition(model, data={"obs": data})
optim = Adam({"lr": 0.2})
svi = SVI(model=obs_model, guide=guide, optim=optim,
          loss=Trace_ELBO(), num_samples=100)

pyro.clear_param_store()
n_steps = 500
losses = []
for i in range(n_steps):
    loss = svi.step(data)
    losses.append(loss)

ELBOの最適化状況をプロットします。

plt.plot(losses)

学習初期に随分暴れていますが最後らへんは安定しています。また、求まった変分パラメータを表示します。

for k, v in pyro.get_param_store().items():
    print(k, v.item())

>>>
m1 9.941473960876465
s1 0.1708463728427887
m2 -0.7829241752624512
s2 0.18723134696483612
\begin{align}
Q(\mu|m_1, s_1) & = \mathrm{Normal}(\mu |9.94, 0.171) \\
Q(\sigma|m_2, s_2) & = \mathrm{LogNormal}(\sigma |-0.783, 0.187)
\end{align}

と求まりました。

変分事後分布をプロットしてみます。

# mu
f = dist.Normal(9.94, 0.171)
x = torch.linspace(8, 12, 100)
p = f.log_prob(x).exp()
plt.plot(x, p)

# sigma
f = dist.LogNormal(-0.784, 0.154)
x = torch.linspace(0, 1, 100)
p = f.log_prob(x).exp()
plt.plot(x, p)