TensorFlowのQueueを用いた非同期について


※ 著者の勘違いが孕んでいる可能性があるので注意して読んでほしいです。

強化学習の分野にあるA3Cという非同期でActor Criticを学習するアルゴリズムを勉強するにあたって、どのようにして非同期をTensorFlowで実現するのか調べたのでまとめました。

申し訳程度にいくつかのクラスについての説明を載せていますが、それらのクラスを用いたプログラムを最後に載せているので、動かしてから調べてもらったほうが良いかもしれません。

キュー(Queue)

tf.FIFOQueue

TensorFlowで複数のスレッドを非同期にテンソルを計算する便利なキューです。
例えば、非同期で動かしたいメソッドをこのキューに追加(queue.enqueue(function))していきます。他にも、ランダムにデキューをするキューが実装されており、tf.RandomShuffleQueueがあります。

必要な引数は以下の二つです。

capacity: integer型。 キューに追加できる上限数。
dtypes: リストに入るオブジェクトの型。

詳しくはtf.FIFOQueueを参照してください。

tf.train.QueueRunner

オペレーション(operation)をenqueue()したものをリストに保持し、それぞれをスレッド内で実行します。

qr = tf.train.QueueRunner(queue, [queue.enqueue() for _ in [...]])
tf.train.queue_runner.add_queue_runner(qr)

詳しくはtf.train.QueueRunnerを参照してください。

tf.train.Coordinator

tf.train.Coordinatorクラスは開始したスレッドを調整してくれます。tf.train.start_queue_runners()はセッション内のグラフで収集されたすべてのtf.QueueRunner()のスレッドを実行してくれます。この二つを併用することで、非同期に実行できます。

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)

詳しくはtf.train.Coordinatorを参照してください。

実装例

非同期に変数(loss)から一様乱数(tf.random_uniform)を減算するだけのものです。

import logging

import tensorflow as tf


class Group(object):
  def __init__(self, scope):
    with tf.name_scope(scope):
      self.loss = tf.Variable(1000., trainable=False, name="loss")

  def loss_op(self, queue: tf.FIFOQueue):
    loss = tf.assign_sub(self.loss, tf.random_uniform(shape=[], maxval=10.))
    print("loss op: {0}".format(self.loss))
    tf.summary.scalar(self.loss.name, self.loss)
    return queue.enqueue(loss)


def main(_):
  thread_size = 5
  groups = [Group(scope="thread_{0}".format(i)) for i in range(thread_size)]
  queue = tf.FIFOQueue(capacity=thread_size * 10,
                       dtypes=[tf.float32], )
  qr = tf.train.QueueRunner(queue, [g.loss_op(queue) for g in groups])
  tf.train.queue_runner.add_queue_runner(qr)
  loss = queue.dequeue()
  mean_loss = tf.reduce_mean([g.loss for g in groups])

  init = tf.global_variables_initializer()
  with tf.Session() as sess:
    init.run()
    summaries = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('log', sess.graph)
    sess.graph.finalize()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)

    try:
      for i in range(100):
        if coord.should_stop():
          break
        _loss, _mean, _summaries = sess.run([loss, mean_loss, summaries])
        print("loss: {0:0.2f}, mean loss: {1:0.2f}".format(_loss, _mean))
        summary_writer.add_summary(_summaries, global_step=i)
    except Exception as e:
      coord.request_stop(e)
    finally:
      coord.request_stop()
      coord.join(threads)


if __name__ == '__main__':
tf.app.run()

TensorBoard