Pyro概要:生成モデル実装ライブラリ(一)、モデル

4710 ワード

概要:
Pythonフレームワークで生成モデルを実現するには,最も基本的に実現しなければならないのが確率関数である.このような関数の実装には、2つの要素が含まれています.
  • 確定的Pythonコード
  • 乱数発生器
  • 具体的には、ランダム関数は、__call__()メソッドを有する任意のPythonオブジェクト、またはPytorchフレームワーク内のnn.Moduleメソッドであってもよい.このチュートリアルでは、すべてのランダム関数をモデルと呼び、モデルを表現する方法は、通常のPython方法と変わりません.
    インストール:
    pip install pyro-ppl
    

    または、
    git clone https://github.com/uber/pyro.git
    cd pyro
    python setup.py install
    

    注意:Pyroはpython 3.*のみをサポートしています.Python 2.*はサポートされていません!
    最も基礎的なランダム関数モデル
    PyroはPytorchのdistribution librarayを利用した.例を挙げると、xをサンプリングして標準正規分布に従います.
    import torch
    import pyro
    
    pyro.set_rng_seed(101)
    
    loc = 0. #    0
    scale = 1. #    1
    
    normal = torch.distributions.Normal(loc, scale) #           
    x = normal.rsample() # N(0,1)  
    print("sample", x)
    print("log prob", normal.log_prob(x)) #     
    
    ##  :
    ##sample tensor(-1.3905)
    ##log prob tensor(-1.8857)
    
    normal = pyro.distributions.Normal(loc, scale) #           
    x = normal.rsample() # N(0,1)  
    print("sample", x)
    print("log prob", normal.log_prob(x)) #     
    
    ##  :
    ##sample tensor(1.3834)
    ##log prob tensor(-1.8759)
    

    ここでtorch.distributions.NormalDistributionクラスのサブクラスであり、サンプリングと採点機能を実現している.Pyroのライブラリpyro.distributionstorch.distributionsの方法をパッケージ化し,Pytorchの数学的方法と自動導出機能を利用した科学研究を行った.
    簡単な例です
    私たちが手に持っているデータを仮定して、日常の気温と曇りと晴れを記録しました.私たちは気温と曇りと晴れの関係を研究したい.そこで、まずPytorchで次の関数を構築し、データの生成プロセスを説明します.
    def weather():
        cloudy_ = torch.distributions.Bernoulli(0.3).sample()
        cloudy = 'cloudy' if cloudy_.item() == 1. else 'sunny'
        mean_temp = {'cloudy': 55., 'sunny': 75.}[cloudy] #   ,       
        scale_temp = {'cloudy': 10., 'sunny': 15.}[cloudy] 
        temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
        return cloudy, temp.item()
    

    この関数を行ごとに説明します.2行目では、2値変数cloudy_を定義し、その値が1である確率を表すバーヌリーパラメータ0.3を定義する.3行目cloudyは文字列変数であり、値は「cloudy」(曇り)または「sunny」(晴れ)である.モデルは、30%の可能性は曇り、70%の可能性は晴れと規定している.第4行では、曇り空の平均温度は華氏55度(約12.78°C)、晴天は華氏75度(約23.89°C)と規定されている.5行目は標準差を規定している.次に、Pyroを使用して上記の関数を書き直します.
    def weather():
        cloudy_ = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
        cloudy = 'cloudy' if cloudy_.item() == 1.0 else 'sunny'
        mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
        scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
        temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
        return cloudy, temp.item()
    
    for _ in range(3):
        print(weather())
    
    ###  :
    #('cloudy', 64.5440444946289)
    #('sunny', 94.37557983398438)
    #('sunny', 72.5186767578125)
    

    表面的にはpyro.sampleしか利用していませんが、そうではありません.もし私たちが聞きたいならば、サンプリングが70度で、どのくらいの確率が曇り空ですか?Pyroでこの質問に答える方法を利用して、チュートリアルの連続性を損なわないように、次のチュートリアルで説明します.
    フレームワークの汎用性:再帰ランダム関数、高次ランダム関数、乱数制御フロー
    次の簡単なモデルを定義します.
    def ice_cream_scales():
        cloudy, temp = weather()
        expected_sales = 200. if cloudy == 'sunny' and temp > 80. else 50.
        ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.))
        return ice_cream
    

    ここまで満足でした.Pyroはより複雑なモデルをカバーできるかどうかを聞きます.答えは肯定的だ.PyroはPythonコードに基づいているため、ランダム数が含まれていても問題なく、任意の複雑な制御フローを定義することができます.例えば、pyro.sampleの決定的なサンプリングに入力するたびに、再帰的なランダム関数を定義することができる.さらに、幾何学的分布、すなわち、実験が成功するまで、実験が失敗した回数をカウントすることを定義します.
    def geometric(p, t=None):
        if t is None:
            t = 0
        x = pyro.sample('x_{}'.format(t), pyro.distributions.Bernoulli(p))
        if x.item() == 1.:
            return 0
        else:
            return 1 + geometric(p, t + 1)
    print(geometric(0.5))
    #   ,              
    # 3 
    

    なお、上記のgeometric()関数では、x_0x_1のような変数が動的に生成される.他のランダム関数の結果を新しい定義関数の変数としたり、新しい関数を作成したりすることもできます.次の例を見てください.
    def normal_product(loc, scale):
        z1 = pyro.sample('z1', pyro.distributions.Normal(loc, scale))
        z2 = pyro.sample('z2', pyro.distributions.Normal(loc, scale))
        y = z1 * z2
        return y
    
    def make_normal_normal():
        mu_latent = pyro.sample('mu_latent'', pyro.distributions.Normal(0., 1.))
        fn = lambda scale: normal_product(mu_latent, scale)
        return fn
    
    make_normal_normal()において、入力の1つはランダム関数であり、このランダム関数は3つのランダム変数を含む.PyroはPythonコードの様々な形式をサポートしています:ループ、再帰、高次関数など.これはPyroが汎用性を有し、PyroがPytorchに基づいてGPU加速を柔軟に使用できることを意味する.次のチュートリアルでは、Pyroの強力な使い方を紹介します.