[論文解説] IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures


この記事は,以下の論文の解説です.

IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures (ICML 2018)

ただし,アルゴリズムを解説することを目的とし,検証実験に関しては触れていません.(めんどくさかったため)

記事内容では,強化学習の基礎的な知識を前提としています.
不備がございましたら,ご指摘頂けると幸いです.

概要

この論文では, 学習の安定性や効率を損なうことなく,数千のマシン上で学習を行うことのできる分散強化学習の手法 = IMPALA,およびIMPALAでも利用されている off-policyでの学習を実現するためのアルゴリズム = V-trace を提案しています.

IMPALA

図 : https://deepmind.com/blog/article/impala-scalable-distributed-deeprl-dmlab-30 から引用.

IMPALAでは,Actor-Criticの設定を用いて,方策 $\pi$ と価値関数 $V^\pi$ を学習します.IMPALAのアーキテクチャでは,データ収集のみを行う複数のプロセス(Actor)と,off-policy で学習を行う1個以上のプロセス(Learner)で構成されます(上図).

Actorは $n$ ステップごとにデータ収集を行います.まず自身の方策 $\mu$ をLearnerの方策 $\pi$ に更新し,$n$ ステップ間データ収集を行います.その後,収集した経験データ(状態,行動,報酬),方策の分布 $\mu(a_t|x_t)$ ,LSTMの初期状態をLearnerに送ります.一方,Learnerは複数のActorから送られてきたデータを用いて,繰り返し学習を行います.

このデータ収集と学習が分離されたシンプルな構成により,以下のメリットがあります.

  • Learnerは各Actorが収集したデータを利用できるので,GPUを活用してデータ並列化できる.またデータ収集を学習から分離することで,LearnerのGPU利用効率が向上する.
  • Actorを多数のマシンに分散させることが容易となる.またActorはそれぞれ非同期なので,1ステップにかかる時間にばらつきがある環境でも,同期のための待ち時間が生じない.
  • 勾配情報に比べて経験データはサイズが小さいので,通信性能が向上する.

しかし,この構成では データ収集時の方策 $\mu$ と学習している方策 $\pi$ が必ずしも一致しない という問題が生じます.そこで,この方策のズレを補正する V-trace と呼ばれるアルゴリズムを提案しています.V-trace を用いることで,サンプル効率を損なうことなく,IMPALAアーキテクチャが持つ高いスループット性能を獲得することができます.さらに,IMPALAでは 複数Learnerを用いることで,複数GPUを効率的に活用することができる というメリットがあります(上図右).

V-trace

データ収集と学習が分離した構成では,データを収集する方策と学習する方策の間にズレが生じるため,off-policy での学習がキモとなります.そこで,この方策のズレを補正するために V-trace と呼ばれるアルゴリズムを提案していきます.

以下では,無限時間のMDPにおいて,期待割引報酬和(価値関数) $V^\pi(x):=\mathbb E_\pi[\sum_{t\ge0} \gamma^t r_t]$ を最大化する方策 $\pi$ を見つけることを考えます.ここで,$\gamma \in [0,1)$ は割引率,方策は確率的方策 $a_t \sim \pi(\cdot|x_t)$ とします.以下では,行動方策 $\mu$ で収集されたデータを用いて,学習方策 $\pi$ の価値関数 $V^\pi$ を学習することを考えます.

V-traceオペレータ

$n$ ステップの V-trace オペレータ $\mathcal R^n$ を以下のように定義します.

\mathcal R^n V(x_s) := V(x_s) + \mathbb E_\mu[
\sum_{t=s}^{s+n-1} \gamma^{t-s}(c_s\cdots c_{t-1})\delta_tV
]

ここで,$\delta_tV := \rho_{t}(r_t+ \gamma V(x_{t+1})-V(x_t))$ は,Importance Sampling (IS)により重み付けられたTD誤差を表します.また,$\rho_i = \min(\bar \rho,\frac{\pi(a_i|x_i)}{\mu(a_i|x_i)})$ ,$c_i = \min(\bar c, \frac{\pi(a_i|x_i)}{\mu(a_i|x_i)})$ は,クリップされたISの重み係数を表し,クリップの閾値は $\bar \rho \ge \bar c$ を満たすこととします.

ここで,on-policy ( $\mu = \pi$ ) の場合には,

\begin{align}
\mathcal R^n V(x_s) &= V(x_s) + \mathbb E_\mu[
\sum_{t=s}^{s+n-1} \gamma^{t-s} (r_t+ \gamma V(x_{t+1})-V(x_t))] \\
&= \mathbb E_\mu[
\sum_{t=s}^{s+n-1}\gamma^{t-s} r_t + \gamma^n V(x_{s+n})]
\end{align}

と変形でき,on-policy の場合の V-trace は,オンライン学習時の $n$ ステップベルマンオペレータに一致することがわかります.

V-trace において,2つの異なるIS重み係数の閾値 $\bar \rho$ と $\bar c$ は異なる役割を果たします.

まず 重み係数 $\rho_i$ の閾値 $\bar \rho$ は,V-trace オペレータの唯一の不動点を定義している と考えることができます.関数近似誤差を生じない表形式の場合には,V-trace オペレータは,以下の式で表される方策 $\pi^{\bar \rho}$ の価値関数 $V^{\pi_{\bar \rho}}$ を唯一の不動点として持ちます(実際に計算すると,確認できます).

\pi_{\bar \rho}(a|x) := \frac{\min(\bar \rho \mu(a|x), \pi(a|x))}{\sum_b\min(\bar \rho \mu(b|x), \pi(b|x))}

よって,$\bar \rho$ が無限のとき,V-trace オペレータは方策 $\pi$ の価値関数 $V^\pi$ を唯一の不動点として持ち.$\bar \rho < \infty$ の場合には,$\pi$ と $\mu$ の間の方策の価値関数を不動点として持つことになります.ISを行うと,行動方策 $\mu$ の確率密度が低いデータにおいて重み係数が非常に大きくなり,推定分散が大きくなってしまうので,$\bar \rho$ でクリップすることで分散を抑制します.よって,$\bar \rho$ が大きいときほど off-policy の学習における分散が大きく(一方でバイアスが小さく),$\bar \rho$ が小さくなるにつれ分散が小さく(バイアスが大きく)なります.また,$\rho_i$ は $c_i$ と異なり時系列で掛け合わせる操作を行わないので,時系列によって発散するような挙動は生じません.

次に 重み係数 $c_i$ の閾値 $\bar c$ は,V-trace オペレータの収束の速さを制御している と考えることができます.重み係数の掛け合わせ $(c_s\cdots c_{t-1})$ は,時刻 $t$ でのTD誤差 $\delta_t V =\rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))$ が時刻 $s$ の価値関数 $V(x_s)$ の更新にどの程度寄与するかを評価しています.$c_i$ は時系列で掛け合わせる操作を伴うため発散しやすく,分散抑制のために重み係数 $c_i$ をクリップすることが重要です.ここで,$\bar c$ の大きさは V-trace オペレータの不動点(学習が収束する点)には影響を与えない ので,より分散を抑制するために $\bar \rho$ よりも小さな値に設定することが望ましいです.

実際には,V-trace は以下のように再帰的に計算することができます.

\mathcal R^n V(x_t) = V(x_t) + + \mathbb E_{\mu}[
\delta_t V+ \gamma c_t (\mathcal R^n V(x_{t+1}) - V(x_{t+1}))
]

V-trace Actor-Critic

方策 $\pi_\omega$ をパラメータ $\omega$ で,価値関数 $V_\theta$ をパラメータ $\theta$ で関数近似します.経験データは行動方策 $\mu$ で収集されたものとします.

価値関数の学習には,V-trace オペレータにおけるTD誤差を損失関数として用います.

L_\theta = (\mathcal R^n V(x_s) - V_\theta(x_s))^2

勾配は容易に計算可能で,以下の式で表せます.

\nabla_\theta L_\theta = (\mathcal R^n V(x_s) - V_\theta(x_s)) \nabla_\theta V_\theta(x_s)

また,方策勾配定理とISにより,方策 $\pi_{\bar\rho}$ の勾配は以下の式で表せます.

E_{a_s\sim \mu}[\frac{\pi_{\bar \rho}(a_s|x_s)}{\mu(a_s|x_s)} \nabla_\omega \log \pi_{\bar \rho}(a_s|x_s) (q_s - b(x_s)) | x_s]

ここで,$q_s = r_s + \gamma \mathcal R^n V(x_{s+1})$ は V-trace オペレータの下での行動価値関数 $Q^{\pi_\omega}(x_s,a_s)$ の推定値を表し,$b(x_s)$ は分散を抑制するための状態依存のベースライン関数を表します.クリッピングよるバイアスが極めて小さい( $\bar \rho$ が十分大きい)場合,上述の勾配は $\pi_{\omega}$ の方策勾配の良い推定値であると考えられます.よって,ベースライン関数に$V_\theta(x_s)$ を用いることで,以下の方策勾配を得ます.

\nabla_\omega L_\omega = \rho_s \nabla_\omega \log \pi_\omega(a_s|x_s) (
r_s + \gamma \mathcal R^n V(x_{s+1}) - V_\theta(x_s))

また,方策 $\pi_\omega$ が局所解に収束してしまうのを防ぐため,エントロピー損失を加えることも考えられます.

L_{\rm ent} = - \sum_a \pi_\omega(a|x_s)\log \pi_\omega(a|x_s)

IMPALAでは,これら3つの損失関数を用いて学習を行います.

(検証実験の解説の需要がある方は,コメントお願いします.)