Jaxで 自動微分を使って Unsupervised Kernel Regression 実装


この記事は古川研究室 Advent_calendar 14日目の記事です。
本記事は古川研究室の学生が学習の一環として書いたものです。内容が曖昧であったり表現が多少異なったりする場合があります。


本記事ではJaxを使ってUnsupervised Kernel Regression:UKRを実装します.

UKRについては,以前のJuliaで実装した記事ですでに概要を説明しましたので本記事では主に実装できたよってことだけ紹介します.

UKRをもうちょっと詳しく知りたいという人は同カレンダーの22日の記事を参考にしてください.


私にとってのJax

Jax は高性能な機械学習研究のために結集された自動微分とXLAです.

PythonとNumpyのネイティブ関数を自動的に微分できるようになっています.

つまり,Google主導で作っている機械学習用のライブラリです.

numpyにない自動微分やGPUのサポートとかを簡単にしてくれるものだと認識しています.

僕らは基本的に,アルゴリズムはScratch実装してきました.微分を実装するときも,アルゴリズムを導出->(紙で)手動で微分.してその解を実装していました.

最近はラボの基盤技術も変わってきており,連続な関数を扱うことが増えてきましたので,そこで最近になって自動微分が使われるようになってきました.

ラボ内スタンダードはPytorchですが,Numpyとの乖離が激しいのがつらいところです.例えばそれは,変数名がちょっと違うとこだったり,引数が違うとこだったり(axis/dim),変数.detach().numpy()としないとnumpy形式に変換できなかったり,numpy形式にしないとmatplotlibが動かなかったり...

特に,プログラミングを習うのがそもそも初めての新入生にとっては頑張ってnumpyに慣れたところに,微分を使うならpytorchだよ✨っていうのもちょっと心苦しいわけです.

そこでJaxですよ.奥さん.

変数名はnumpyとほぼ同じです.(ランダムとかはちょっと違いますが).

自動微分の機能を覚えるだけです.

また,GPUを積んでたら勝手にGPUを使ってくれます.

それじゃさっそくJaxを使っていきましょう.

UKR実装

Preliminary

import numpy as np
import jax
import jax.numpy as jnp
from jax.ops import index, index_add, index_update # 配列の一部を変えるためのメソッド(未使用)

データの作成

入力となるデータを作ります.またしても,双曲面データです.


def create_data(samples, seed=1):
    # latent_dim =2
    # x_dim =2
    rng = jax.random.PRNGKey(seed)
    # z = jax.random.normal(rng, [samples, 2]) 
    z = jax.random.uniform(rng, [samples, 2], minval=-1, maxval=1) 
    x_1 = z[:, 0][:, None]
    x_2 = z[:, 1][:, None]
    x_3 = (z[:, 0] ** 2 - z[:, 1] ** 2)[:, None]
    return jnp.concatenate([x_1, x_2, x_3], axis=1)

samples = 200
x = create_data(samples)
global x # ここglobalなんだよね・・

潜在変数の初期化

潜在変数のスケールは$\sigma$より小さくする必要があるため,*1e-04をかけています.

    rng = jax.random.PRNGKey(seed)
    z = jax.random.normal(rng, shape=[N, L]) * 0.00001

次に,$(f(\mathbf{z}_n))$をつくる式を作ります.


def estimate_x(z1, z2, x):
    # This is Nadaraya-Watson Esitmator
    # Need z2 shape (N,)
    Dist = jnp.sum((z1[:, None, :] - z2[None, :, :])**2, axis=2)
    H = jnp.exp(-0.5 * Dist)
    G = jnp.sum(H, axis=1)[:, None]
    R = H/G
    x_hat = R@x
    return x_hat

次にE(Z)をつくります.

def E(z):
    x_hat = estimate_x(z, z, x)
    E = jnp.sum((x - x_hat)**2)
    return E

じゃあ学習します.

    for e in range(epoch):
        grad_E = jax.grad(E)
        z -= eta * grad_E(z)
        # all_z = index_add(all_z, index[e + 1,:,:], z)
        all_z[e, :, :] = z
    return all_z

E(z)をつくるのは,微分したい変数のみを入力とした関数をつくる必要があるからです.
つまり, $\frac{\partial E(Z)}{\partial Z}$の$\partial Z$を関数の引数で表現するということですね.

なので微分に使う変数で他に指定したいときは,globalやクラスのself.を使って,暗黙的に使用できる必要があります.
これをこうしなくても回避する方法があったら教えてください.

コメントで教えていただきました.gradの引数にargnumsというのがあって,そこの引数で微分したい引数を選択できるみたいです.
なので,微分する関数の引数はE(Z, X)としても構いません.

学習結果

コードはGistにはっており,colabでも開けるようになっているので,必要なコードがある人がいたらみてみてください.


注意点

  • jaxは基本32bit演算(オプションで64にはできる)
  • 配列の一部を入れ替えるっていう操作は癖あり(ただ,微分に関わるところだけjax numpy形式を使えば良いからあまり困らない.基本はnumpy形式でいいはず)(numpyの変数[i,:] = jaxの変数ができるので保存用の変数はそうすればよい)