Box-Cox変換を含んだモデルをベイズ推定する


やりたいこと

Box-Cox変換は以下の式で定義される変数変換で対数変換の一般化だそうです。($\lambda$->0の極限で対数変換になる)

\mathrm{BoxCox}(y|\lambda) = \frac{y^\lambda - 1}{\lambda}

この変換を行うことで正の実数が実数全体に写像されます。したがって変数が正の実数で正規分布が使用できない場合にでもこの変換を使用することで正規分布に近づけることができます。

Box-Cox変換は式の通りパラメータ$\lambda$を含んでおり、この$\lambda$もデータに依存してチューニングするべきものです。ハイパーパラメータとして探索してもいいですし、SciPyに実装されているように線形モデルを仮定した場合には他のモデルパラメータに依存せずに最尤推定することができますが、ここでは$\lambda$も含めてベイズ推定する方法を実装してみます。

モデル

\begin{align}
z & = \mathrm{BoxCox}(y|\lambda) \\
z & \sim \mathcal{N(z|\mathbf{w}^T\mathbf{x}, \sigma^2)}
\end{align}

上記のように$y$ではなく、Box-Cox変換した$z$が正規分布に従っているという線形回帰モデルを考えます。このモデルにおいて$y$を変数に取った尤度は単純な変数変換だけではなく、次のようにJacobianを含んだ形になります。

\begin{align}
P(y|\mathbf{x}, \mathbf{w}, \sigma^2, \lambda) & = P(z|\mathbf{x}, \mathbf{w}, \sigma^2, \lambda)\frac{dz}{dy} \\
 & = \mathcal{N(z|\mathbf{w}^T\mathbf{x}, \sigma^2)} y^{\lambda - 1} \\
\ln P(y|\mathbf{x}, \mathbf{w}, \sigma^2, \lambda) & =  \ln \mathcal{N(z|\mathbf{w}^T\mathbf{x}, \sigma^2)} + (\lambda-1)\ln y
\end{align} 

Stanによる実装

これをStanで実装すると以下の通りになります。

functions {
    /* ... function declarations and definitions ... */
    vector boxcox(vector y, real lambda, int N) {
        vector[N] z;
        real inv_lambda = 1.0 / lambda;
        for (i in 1:N){
            z[i] = (pow(y[i], lambda) - 1.0) * inv_lambda;
        }
        return z;
    }

    vector inv_boxcox(vector z, real lambda, int N) {
        vector[N] y;
        real inv_lambda = 1.0 / lambda;
        for (i in 1:N){
            y[i] = pow((z[i] * lambda + 1.0), inv_lambda);
        }
        return y;
    }
}

data {
    /* ... declarations ... */
    int N;
    int N_new;
    int D;
    matrix[N, D] X;
    vector[N] y;
    matrix[N_new, D] X_new;
    real<lower=0> w_s;
    real<lower=0> sigma_beta;
    real<lower=0> lambda_s;
}

parameters {
    /* ... declarations ... */
    vector[D] w;
    real<lower=0> sigma;
    real lambda;
}

transformed parameters {
    /* ... declarations ... statements ... */
    vector[N] z;
    z = boxcox(y, lambda, N);
}

model {
    /* ... declarations ... statements ... */
    w ~ normal(0.0, w_s);
    sigma ~ exponential(1.0 / sigma_beta);
    lambda ~ normal(1.0, lambda_s);
    z ~ normal(X * w, sigma);
    target += (lambda - 1.0) * log(y);
}

generated quantities {
    /* ... declarations ... statements ... */
    vector[N_new] yp;
    if (N_new > 0){
        yp = inv_boxcox(to_vector(normal_rng(X_new * w, sigma)), lambda, N_new);
    }
}

ポイントとしては、Stanの自動微分は独立変数の方しか面倒を見てくれないので従属変数($y$)の方の変数変換は自分でJacobiantarget +=の文法で足してあげる必要があるという点です。$\lambda$は中心が1.0の正規分布を事前分布に設定しています。これは$\lambda=1.0$の時は定数項を足している以外は何もしていない変換になるのでそれを中心にしたかったからです。X_new,N_newgenerated quanitiesはパラメータ推定とは関係ないですが、推定したパラメータのサンプルを使用して未知データに対する予測を行うときに使用します。

実験

出力がPoisson分布の線形モデルのテストデータを作成し、Poisson回帰の代わりにBox-Cox変換を含んだ線形回帰モデルで係数を推定してみます。

\ln h = 5.0 - 1.2 x \\
y \sim \mathrm{Po}(y|h)
N = 50
D = 2

w = np.array([5.0, -1.2])
X = np.random.randn(N, D - 1) * 0.1
X = np.column_stack([np.ones(N), X])
h = np.exp(np.dot(X, w))
y = np.random.poisson(h)

model = BayesianBoxCoxLinearRegression(
    w_s=2.0,
    sigma_beta=2.0,
    lambda_s=2.0,
    n_warmup=500,
    n_samples=1000,
    n_chains=4,
)
model.fit(X, y)
model.summary()
           Mean    MCSE  StdDev       5%      50%      95%   N_Eff  N_Eff/s  R_hat
name                                                                              
lp__   -150.000  0.0850    1.70 -150.000 -150.000 -150.000   420.0     81.0    1.0
w[1]      4.800  0.0720    1.10    3.200    4.600    6.900   250.0     48.0    1.0
w[2]     -1.100  0.0350    0.55   -2.200   -0.990   -0.430   245.0     47.0    1.0
sigma     0.100  0.0000    0.00    0.000    0.100    0.100   245.9     47.6    1.0
lambda   -0.000  0.0000    0.10   -0.200   -0.000    0.100   266.6     51.6    1.0

$\lambda$ = 0 なので対数変換になり、$\mathbf{w}$も(5.0, -1.2)に近い値が予測できています。

NumPyroによる実装

import jax.numpy as jnp
import numpyro as pyro
import numpyro.distributions as dist
from jax import grad, random, vmap


def boxcox(y: float, lam: float):
    return (y**lam - 1.0) / lam


def inv_boxcox(z: float, lam: float):
    return (z * lam + 1.0) ** (1.0 / lam)


trans_grad = vmap(grad(boxcox), (0, None, 0, None))


def model(
    X: jnp.ndarray,
    y: Optional[jnp.ndarray] = None,
    w_s=2.0,
    sigma_beta=2.0,
    lambda_s=1.0,
):
    n, d = X.shape
    with pyro.plate("d", d):
        w = pyro.sample("w", dist.Normal(0.0, w_s))
    sigma = pyro.sample("s", dist.Exponential(sigma_beta))
    z = jnp.dot(X, w)
    lam = pyro.sample("lambda", dist.Normal(1.0, lambda_s))
    if y is not None:
        yl = transform(y, lam, z, sigma)
        grad_factor = trans_grad(y, lam, z, sigma)
        pyro.factor("", jnp.log(grad_factor).sum())
        with pyro.plate("n", n):
            pyro.sample("yl", dist.Normal(z, sigma), obs=yl)
    else:
        with pyro.plate("n", n):
            yl = pyro.sample("yl", dist.Normal(z, sigma))
            return pyro.deterministic("y", inv_boxcox(yl, lam))

基本的にはStanと同等ですが、JAXの自動微分gradを使ってBox-Cox変換の微分を計算し、vmapでベクトル化している点だけが異なります。もちろんStanの時のように解析的に求めた微分を使用してもいいですが、自動微分を利用することで今後Box-Cox変換以外にもいろいろな変換にも使えます。

その他

コード全体は https://github.com/lucidfrontier45/BaysianBoxCoxLM にあります。StanはPyStanではなくCmdStanPyを使用しています。Stanの場合はgenerated quantites、NumPyroの場合はnumpyro.infer.Predictiveを使用して事後予測分布を生成する機能も付けています。また、全体をBayesianBoxCoxLinearRegressionというクラスでまとめ、使用するbackendとしてCmdStanPyとNumPyroを切り替えられるようにしています。