TensorFlowのQueueを用いた非同期について
※ 著者の勘違いが孕んでいる可能性があるので注意して読んでほしいです。
強化学習の分野にあるA3C
という非同期でActor Critic
を学習するアルゴリズムを勉強するにあたって、どのようにして非同期をTensorFlow
で実現するのか調べたのでまとめました。
申し訳程度にいくつかのクラスについての説明を載せていますが、それらのクラスを用いたプログラムを最後に載せているので、動かしてから調べてもらったほうが良いかもしれません。
キュー(Queue)
tf.FIFOQueue
TensorFlow
で複数のスレッドを非同期にテンソルを計算する便利なキューです。
例えば、非同期で動かしたいメソッドをこのキューに追加(queue.enqueue(function)
)していきます。他にも、ランダムにデキューをするキューが実装されており、tf.RandomShuffleQueue
があります。
必要な引数は以下の二つです。
capacity: integer型。 キューに追加できる上限数。
dtypes: リストに入るオブジェクトの型。
詳しくはtf.FIFOQueue
を参照してください。
tf.train.QueueRunner
オペレーション(operation
)をenqueue()
したものをリストに保持し、それぞれをスレッド内で実行します。
qr = tf.train.QueueRunner(queue, [queue.enqueue() for _ in [...]])
tf.train.queue_runner.add_queue_runner(qr)
詳しくはtf.train.QueueRunner
を参照してください。
tf.train.Coordinator
tf.train.Coordinator
クラスは開始したスレッドを調整してくれます。tf.train.start_queue_runners()
はセッション内のグラフで収集されたすべてのtf.QueueRunner()
のスレッドを実行してくれます。この二つを併用することで、非同期に実行できます。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
詳しくはtf.train.Coordinator
を参照してください。
実装例
非同期に変数(loss
)から一様乱数(tf.random_uniform
)を減算するだけのものです。
import logging
import tensorflow as tf
class Group(object):
def __init__(self, scope):
with tf.name_scope(scope):
self.loss = tf.Variable(1000., trainable=False, name="loss")
def loss_op(self, queue: tf.FIFOQueue):
loss = tf.assign_sub(self.loss, tf.random_uniform(shape=[], maxval=10.))
print("loss op: {0}".format(self.loss))
tf.summary.scalar(self.loss.name, self.loss)
return queue.enqueue(loss)
def main(_):
thread_size = 5
groups = [Group(scope="thread_{0}".format(i)) for i in range(thread_size)]
queue = tf.FIFOQueue(capacity=thread_size * 10,
dtypes=[tf.float32], )
qr = tf.train.QueueRunner(queue, [g.loss_op(queue) for g in groups])
tf.train.queue_runner.add_queue_runner(qr)
loss = queue.dequeue()
mean_loss = tf.reduce_mean([g.loss for g in groups])
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run()
summaries = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter('log', sess.graph)
sess.graph.finalize()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess, coord)
try:
for i in range(100):
if coord.should_stop():
break
_loss, _mean, _summaries = sess.run([loss, mean_loss, summaries])
print("loss: {0:0.2f}, mean loss: {1:0.2f}".format(_loss, _mean))
summary_writer.add_summary(_summaries, global_step=i)
except Exception as e:
coord.request_stop(e)
finally:
coord.request_stop()
coord.join(threads)
if __name__ == '__main__':
tf.app.run()
TensorBoard
Author And Source
この問題について(TensorFlowのQueueを用いた非同期について), 我々は、より多くの情報をここで見つけました https://qiita.com/ashigirl966/items/99b0f8d9713ee90db13a著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .