Pyro概要:生成モデル実装ライブラリ(一)、モデル
4710 ワード
概要:
Pythonフレームワークで生成モデルを実現するには,最も基本的に実現しなければならないのが確率関数である.このような関数の実装には、2つの要素が含まれています.確定的Pythonコード 乱数発生器 具体的には、ランダム関数は、
インストール:
または、
注意:Pyroはpython 3.*のみをサポートしています.Python 2.*はサポートされていません!
最も基礎的なランダム関数モデル
PyroはPytorchのdistribution librarayを利用した.例を挙げると、
ここで
簡単な例です
私たちが手に持っているデータを仮定して、日常の気温と曇りと晴れを記録しました.私たちは気温と曇りと晴れの関係を研究したい.そこで、まずPytorchで次の関数を構築し、データの生成プロセスを説明します.
この関数を行ごとに説明します.2行目では、2値変数
表面的には
フレームワークの汎用性:再帰ランダム関数、高次ランダム関数、乱数制御フロー
次の簡単なモデルを定義します.
ここまで満足でした.Pyroはより複雑なモデルをカバーできるかどうかを聞きます.答えは肯定的だ.PyroはPythonコードに基づいているため、ランダム数が含まれていても問題なく、任意の複雑な制御フローを定義することができます.例えば、
なお、上記の
Pythonフレームワークで生成モデルを実現するには,最も基本的に実現しなければならないのが確率関数である.このような関数の実装には、2つの要素が含まれています.
__call__()
メソッドを有する任意のPythonオブジェクト、またはPytorchフレームワーク内のnn.Module
メソッドであってもよい.このチュートリアルでは、すべてのランダム関数をモデルと呼び、モデルを表現する方法は、通常のPython方法と変わりません.インストール:
pip install pyro-ppl
または、
git clone https://github.com/uber/pyro.git
cd pyro
python setup.py install
注意:Pyroはpython 3.*のみをサポートしています.Python 2.*はサポートされていません!
最も基礎的なランダム関数モデル
PyroはPytorchのdistribution librarayを利用した.例を挙げると、
x
をサンプリングして標準正規分布に従います.import torch
import pyro
pyro.set_rng_seed(101)
loc = 0. # 0
scale = 1. # 1
normal = torch.distributions.Normal(loc, scale) #
x = normal.rsample() # N(0,1)
print("sample", x)
print("log prob", normal.log_prob(x)) #
## :
##sample tensor(-1.3905)
##log prob tensor(-1.8857)
normal = pyro.distributions.Normal(loc, scale) #
x = normal.rsample() # N(0,1)
print("sample", x)
print("log prob", normal.log_prob(x)) #
## :
##sample tensor(1.3834)
##log prob tensor(-1.8759)
ここで
torch.distributions.Normal
はDistribution
クラスのサブクラスであり、サンプリングと採点機能を実現している.Pyroのライブラリpyro.distributions
はtorch.distributions
の方法をパッケージ化し,Pytorchの数学的方法と自動導出機能を利用した科学研究を行った.簡単な例です
私たちが手に持っているデータを仮定して、日常の気温と曇りと晴れを記録しました.私たちは気温と曇りと晴れの関係を研究したい.そこで、まずPytorchで次の関数を構築し、データの生成プロセスを説明します.
def weather():
cloudy_ = torch.distributions.Bernoulli(0.3).sample()
cloudy = 'cloudy' if cloudy_.item() == 1. else 'sunny'
mean_temp = {'cloudy': 55., 'sunny': 75.}[cloudy] # ,
scale_temp = {'cloudy': 10., 'sunny': 15.}[cloudy]
temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
return cloudy, temp.item()
この関数を行ごとに説明します.2行目では、2値変数
cloudy_
を定義し、その値が1である確率を表すバーヌリーパラメータ0.3
を定義する.3行目cloudy
は文字列変数であり、値は「cloudy」(曇り)または「sunny」(晴れ)である.モデルは、30%の可能性は曇り、70%の可能性は晴れと規定している.第4行では、曇り空の平均温度は華氏55度(約12.78°C)、晴天は華氏75度(約23.89°C)と規定されている.5行目は標準差を規定している.次に、Pyroを使用して上記の関数を書き直します.def weather():
cloudy_ = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
cloudy = 'cloudy' if cloudy_.item() == 1.0 else 'sunny'
mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
return cloudy, temp.item()
for _ in range(3):
print(weather())
### :
#('cloudy', 64.5440444946289)
#('sunny', 94.37557983398438)
#('sunny', 72.5186767578125)
表面的には
pyro.sample
しか利用していませんが、そうではありません.もし私たちが聞きたいならば、サンプリングが70度で、どのくらいの確率が曇り空ですか?Pyroでこの質問に答える方法を利用して、チュートリアルの連続性を損なわないように、次のチュートリアルで説明します.フレームワークの汎用性:再帰ランダム関数、高次ランダム関数、乱数制御フロー
次の簡単なモデルを定義します.
def ice_cream_scales():
cloudy, temp = weather()
expected_sales = 200. if cloudy == 'sunny' and temp > 80. else 50.
ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.))
return ice_cream
ここまで満足でした.Pyroはより複雑なモデルをカバーできるかどうかを聞きます.答えは肯定的だ.PyroはPythonコードに基づいているため、ランダム数が含まれていても問題なく、任意の複雑な制御フローを定義することができます.例えば、
pyro.sample
の決定的なサンプリングに入力するたびに、再帰的なランダム関数を定義することができる.さらに、幾何学的分布、すなわち、実験が成功するまで、実験が失敗した回数をカウントすることを定義します.def geometric(p, t=None):
if t is None:
t = 0
x = pyro.sample('x_{}'.format(t), pyro.distributions.Bernoulli(p))
if x.item() == 1.:
return 0
else:
return 1 + geometric(p, t + 1)
print(geometric(0.5))
# ,
# 3
なお、上記の
geometric()
関数では、x_0
、x_1
のような変数が動的に生成される.他のランダム関数の結果を新しい定義関数の変数としたり、新しい関数を作成したりすることもできます.次の例を見てください.def normal_product(loc, scale):
z1 = pyro.sample('z1', pyro.distributions.Normal(loc, scale))
z2 = pyro.sample('z2', pyro.distributions.Normal(loc, scale))
y = z1 * z2
return y
def make_normal_normal():
mu_latent = pyro.sample('mu_latent'', pyro.distributions.Normal(0., 1.))
fn = lambda scale: normal_product(mu_latent, scale)
return fn
make_normal_normal()
において、入力の1つはランダム関数であり、このランダム関数は3つのランダム変数を含む.PyroはPythonコードの様々な形式をサポートしています:ループ、再帰、高次関数など.これはPyroが汎用性を有し、PyroがPytorchに基づいてGPU加速を柔軟に使用できることを意味する.次のチュートリアルでは、Pyroの強力な使い方を紹介します.