Python で Java の CountDownLatch を実装してみる


エンジニアの次のステップへの勉強法を読んでいて、これなんだっけ?とふと思い、復習してみた。

そもそも CountDownLatch って何かは Java のドキュメントを参照ください。

サンプルコード

#coding:utf-8

import threading

class CountDownLatch(object):
    def __init__(self, count):
        assert count > 0, "count は 1 以上の値を与えてください。入力値[{}]".format(count)
        self.count = count
        self.condition = threading.Condition()

    def count_down(self):
        self.condition.acquire()
        self.count -= 1
        if self.count == 0:
            self.condition.notifyAll()
        self.condition.release()

    def await(self):
        self.condition.acquire()
        while self.count > 0:
            self.condition.wait()
        self.condition.release()


def woker1(latch):
    latch.await()
    print("Finished!")

def woker2(latch, n):
    latch.count_down()
    print("Worker: {}".format(n))

if __name__ == "__main__":
    count = 5
    latch = CountDownLatch(count)

    # woker1 は woker2 のスレッドが全て完了してから実行される
    thread = threading.Thread(target=woker1, args=(latch,))
    thread.start()

    for i in range(count):
        thread = threading.Thread(target=woker2, args=(latch, i))
        thread.start()

実行結果

Worker: 0
Worker: 1
Worker: 2
Worker: 3
Worker: 4

Finished!