【強化学習】cpprb で Experience Replay を簡単に!


0. はじめに

先日「【強化学習】Ape-X の高速な実装を簡単に!」を公開しましたが、今回はその基礎となる「Experience Replay (経験再生)」に関してより入門者向けの記事を書いてみようと思います。

(インターネット上で、「Experience Replay」と検索するとPythonでフルスクラッチで実装した記事がたくさんヒットする (1234、etc.) ので、もっと簡単に利用できることを書こうと思い)

1. cpprb

拙作の cpprb は、強化学習の Experience Replay 向けに開発しているライブラリです。

1.1 インストール

1.1.1 Linux/Windows

PyPI からそのままバイナリをインストールできます。

pip install cpprb

1.1.2 macOS

残念ながらデフォルトで利用される clang ではコンパイルができないため、Homebrew なり MacPorts なりで、gcc を用意してもらいインストール時に手元でコンパイルしてもらうことが必要です。

/path/to/g++ はインストールした g++ のパスに置き換えてください。

CC=/path/to/g++ CXX=/path/to/g++ pip install cpprb

参考: 公式サイトのインストール手順

2. Experience Replay

2.1 概要

Experience Replayは、エージェントによる環境の探索で得られたトランジションをそのままニューラルネットワークに渡して学習させるのではなく、一旦ためておいて無作為に取り出したサンプルでニューラルネットワークを学習させる方法です。

連続するトランジションに内在する自己相関の影響による学習の不安定性を低減することが知られており、off-policyの強化学習において幅広く利用されています。

2.2 サンプルコード

cpprbを利用したExperience Replayのサンプルコードを以下に掲載します。ニューラルネットワークによるモデルの実装や、可視化・モデルの保存などはこのサンプルには含まれていません。(MockModel の部分の実装を改造してください。)

import numpy as np
import gym

from cpprb import ReplayBuffer

n_training_step = int(1e+4)
buffer_size = int(1e+6)
batch_size = 32

env = gym.make("CartPole-v1")

class MockModel:
    # ここにDQNなどのモデルを実装する
    def __init__(self):
        pass

    def get_action(self,obs):
        return env.action_space.sample()

    def train(self,sample):
        pass

model = MockModel()

obs_shape = 4
act_dim = 1
rb = ReplayBuffer(buffer_size,
                  env_dict ={"obs": {"shape": obs_shape},
                             "act": {"shape": act_dim},
                             "rew": {},
                             "next_obs": {"shape": obs_shape},
                             "done": {}})
# 保存するものを dict 形式で指定する。"shape" と "dtype" を指定可能。デフォルトは、{"shape": 1, "dtype": np.single}


obs = env.reset()

for i in range(n_training_step):
    act = model.get_action(obs)
    next_obs, rew, done, _ = env.step(act)

    # keyword argument として渡す    
    rb.add(obs=obs,act=act,rew=rew,next_obs=next_obs,done=done)

    if done:
        rb.on_episode_end()
        obs = env.reset()
    else:
        obs = next_obs

    sample = rb.sample(batch_size)
    # dict[str,np.ndarray] 形式で、ランダムサンプルされる

    model.train(sample)

3. Prioritized Experience Replay

3.1 概要

Prioritized Experience Replay (優先度付き経験再生) は、Experience Replay を発展させたバージョンで、TD誤差の大きかったトランジションをより優先してサンプルする手法です。

細かい説明はこの記事では省きますが、以下の記事やサイトに解説があります。

3.2 サンプルコード

Experience Replayのサンプルコード同様、ニューラルネットワークによるモデルの実装や、可視化、保存などはこのサンプルには含まれていません。

Prioritized Experience Replay を高速に実行するためには、Segment Tree を利用することが提案されていますが、独自実装するとバグに陥りやすく、またPythonで実装すると遅いということが多々あります。(cpprbでは、Segment Tree は C++ で実装してあり高速です。)

import numpy as np
import gym

from cpprb import PrioritizedReplayBuffer

n_training_step = int(1e+4)
buffer_size = int(1e+6)
batch_size = 32

env = gym.make("CartPole-v1")

class MockModel:
    # ここにDQNなどのモデルを実装する
    def __init__(self):
        pass

    def get_action(self,obs):
        return env.action_space.sample()

    def train(self,sample):
        pass

    def compute_abs_TD(self,sample):
        return 0

model = MockModel()

obs_shape = 4
act_dim = 1
rb = PrioritizedReplayBuffer(buffer_size,
                             env_dict ={"obs": {"shape": obs_shape},
                                        "act": {"shape": act_dim},
                                        "rew": {},
                                        "next_obs": {"shape": obs_shape},
                                        "done": {}},
                             alpha = 0.4)


obs = env.reset()

for i in range(n_training_step):
    act = model.get_action(obs)
    next_obs, rew, done, _ = env.step(act)

    # バッファに加える際に、 priority を直接指定することも可能。未指定時は最大の priority が利用される。
    rb.add(obs=obs,act=act,rew=rew,next_obs=next_obs,done=done)

    if done:
        rb.on_episode_end()
        obs = env.reset()
    else:
        obs = next_obs

    sample = rb.sample(batch_size, beta = 0.4)
    # コンストラクタで指定したトランジションに加えて、"indexes", "weights" が np.ndarray として dict に含まれる

    model.train(sample)

    abs_TD = model.compute_abs_TD(sample)
    rb.update_priorities(sample["indexes"],abs_TD)

4. 困ったときは

ユーザーフォーラムとして GitHub Discussions を開設しましたので、cpprbに関する質問などはこちらへ