Pythonで並列処理のすすめ

35865 ワード

Python(CPython) の並列処理で適切にパフォーマンス改善する方法を解説します。

一般に並列処理で思い付くのはスレッド処理ですが、Python には GIL(グローバル・インタプリタ・ロック)と呼ばれるロック機構があるため、その方法ではパフォーマンス改善を望めないことがあります。
なぜならば Python コードを実行できるのは GIL を保持したスレッドだけなので複数スレッドで並列処理をしたつもりでも実際に実行されるのは1スレッドだけだからです。

当記事はスレッドベースの並列処理の問題とそれに代わる方法であるプロセスベースの並列処理のメリットデメリットをまとめました。ちなみに検証には Python 3.9.12 / Ubuntu 20.04.4 LTS を利用しています。

GIL の影響を見てみる

以下のような素朴なアルゴリズムで素数をカウントする関数を用意しました。
並列処理から実行結果を受け取るために return ではなく Queue を利用しています。

from queue import Queue

def count_primes(num: int, queue: Queue) -> None:
    primes = 0
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes += 1
    queue.put(primes)

まずは以下のようにスレッドを1つ生成して実行します。

from threading import Thread

queue1 = Queue()
thread1 = Thread(target=count_primes, args=(100000, queue1))
thread1.start()
thread1.join()
print(queue1.get())

10万までの素数をカウントしたら30秒かかりました。

$ time python count_primes_thread.py 
9592

real    0m30.982s
user    0m30.965s
sys     0m0.010s

次は2つのスレッドで同時に実行します。

queue1 = Queue()
thread1 = Thread(target=count_primes, args=(100000, queue1))
thread1.start()

queue2 = Queue()
thread2 = Thread(target=count_primes, args=(100000, queue2))
thread2.start()

thread1.join()
print(queue1.get())
thread2.join()
print(queue2.get())

結果はほぼ2倍で約1分間かかりました。

$ time python count_primes_thread.py 
9592
9592

real    1m1.449s
user    1m1.665s
sys     0m0.730s

実行環境は複数コアのCPUなので並列処理できるはずですが、GIL の影響で同時実行が制限されています。

では Python プロセスベースの並列処理でパフォーマンスを改善する方法を考えます。

プロセスベースの並列処理

Python の GIL はプロセス単位で作用するので fork して生成された子プロセスは親プロセスの GIL の影響を受けずに並列処理できます。
そこで利用するのが multiprocessing モジュールで、スレッドとよく似たインターフェースでプロセスを扱えます。

ちなみに Python にはプロセスベースの並列処理における子プロセスの生成方法がいくつか用意されていますが、当記事では Linux 環境のデフォルトである fork を前提としています。

先ほどのスレッド(threading)を利用したプログラムを multiprocessing モジュールで書き換えてみます。
とは言っても、両者はほぼ共通のインターフェースを持つため、この場合は import するクラスを置き換えるだけの修正で動きます。

from multiprocessing import Process, Queue

def count_primes(num: int, queue: Queue) -> None:
    primes = 0
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes += 1
    queue.put(primes)

queue1 = Queue()
process1 = Process(target=count_primes, args=(100000, queue1))
process1.start()

queue2 = Queue()
process2 = Process(target=count_primes, args=(100000, queue2))
process2.start()

process1.join()
print(queue1.get())
process2.join()
print(queue2.get())

実行すると先ほど1分間以上かかった処理が34秒で終了し、約2倍に高速化しました。

$ time python count_primes_process.py 
9592
9592

real    0m34.479s
user    1m8.574s
sys     0m0.010s

プロセスベースの並列処理なら Python コードも並列実行できることが分かります。

スレッドと似たインターフェースで扱いやすいのは良いのですが、実際にはプロセスとスレッドは本質的に異なるものであり不適切な利用をすると望まぬ結果になります。そこで次はプロセスベースの並列処理でハマりやすいポイントを解説します。

スレッドはメモリ空間を共有、プロセスは独立

スレッドは親プロセスとメモリ空間を共有するので、例えばグローバル変数の値を変更すればスレッドの呼び出し元もその変更結果を取得できます。
一方で子プロセスは fork のタイミングで親プロセスからコピーされたメモリ空間を持つため、親プロセスの持っていた変数を参照することができますが、子プロセス側で変数を変更してもその子プロセスが終了すれば破棄されて親プロセスには影響しません。

実際に試してみます。

from multiprocessing import Process
from threading import Thread

global_value = 0

def worker() -> None:
    global global_value
    global_value += 1
    print(f"in worker               : {global_value=}")

# スレッドを生成・実行開始して終了まで待つ
print(f"before thread execution : {global_value=}")
thread = Thread(target=worker)
thread.start()
thread.join()
print(f"after thread execution  : {global_value=}")

# プロセスを生成・実行開始して終了まで待つ
print(f"before process execution: {global_value=}")
process = Process(target=worker)
process.start()
process.join()
print(f"after process execution : {global_value=}")

実行結果は以下のようになりました。

$ time python thread_process.py 
before thread execution : global_value=0
in worker               : global_value=1
after thread execution  : global_value=1
before process execution: global_value=1
in worker               : global_value=2
after process execution : global_value=1

スレッドベースの並列処理の方は、スレッドが起動して global_value1 加算した結果を呼び出し側でも受け取れました。
次にプロセスベースの並列処理の方は、worker 関数内の出力は global_value=2 となっているのでグローバル変数の値を参照することはできています。しかし worker 関数が終了(子プロセスが終了)して親プロセスに制御が戻ると global_value=1 に戻ってしまいました。なぜならば global_value=2 に更新されたのは子プロセスの中だけで、親プロセス側の global_value は影響を受けないからです。

ちなみにこのサンプルプログラムは1つのプログラム内でスレッドとプロセスを両方とも生成していますが、プロセス生成のタイミングで複数スレッドが動作している状況は本質的に安全ではない点は注意してください。プロセスベースの並列処理を起動する前にスレッドを終了しておくと安全です。※公式ドキュメントのforkに関する説明を参照

データの受け渡しはプロセス間通信

次はデータの受け渡しに利用した Queue クラスについて考えてみます。

スレッドの場合はメモリ空間を共有しているので、データのやりとりは同じメモリ空間内で直接行うことができます。
一方でプロセスベースの並列処理で生成される子プロセスは親子関係こそあるものの相互に独立したプロセスなので、データのやりとりにはプロセス間通信を利用してバイト列を送受信します。

スレッドとプロセスは両方とも似たインターフェースを持つ Queue クラスでやりとりしましたが、実はその内部実装は大きく異なるのです。

動作原理の違いを理解するために、まずはスレッドから defaultdict(lambda: 1) を受け取ってみます。
defaultdict は存在しないキーへアクセスされた場合のデフォルト値を callable の実行結果とすることができるので、lambda と連携してよく使われます。

以下のコードは指定した数値までに含まれる素数をキーに設定した defaultdict を返却します。

from collections import defaultdict
from threading import Thread
from queue import Queue

def get_primes(num: int, queue: Queue) -> None:
    primes = defaultdict(lambda: 1)
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes[i]
    queue.put(primes)

queue = Queue()
thread = Thread(target=get_primes, args=(100000, queue))
thread.start()
thread.join()
primes = queue.get()
print(sum(primes.values()))

スレッドベースの並列処理では問題なく defaultdict を受け取ることができました。

では次にプロセスベースの並列処理で同じことをやってみます。

from collections import defaultdict
from multiprocessing import Process
from multiprocessing import Queue

def get_primes(num: int, queue: Queue) -> None:
    primes = defaultdict(lambda: 1)
    for i in range(2, num + 1):
        for j in range(2, i):
            if i % j == 0:
                break
        else:
            primes[i] = 1
    queue.put(primes)

queue = Queue()
process = Process(target=get_primes, args=(100000, queue))
process.start()
process.join()
primes = queue.get()
print(sum(primes.values()))

しかしこれは残念ながら実行すると永遠に終了しません。

multiprocessing.Queue はプロセス間通信を行うクラスなので受け渡すデータをシリアライズしてバイト列に変換し、受け取り側はデシリアライズします。しかし defaultdict に渡した lambda: 1 はシリアライズできません。
子プロセス側は素数の計算処理を実行して Queue へデータを受け渡しますが、そのタイミングで失敗してしまいます。そして親プロセス側は primes = queue.get()Queue からのデータを待つのでストールします。

シリアライズには pickle を利用するので pickle が対応しているオブジェクトはシリアライズが可能です。例えば以下のように lambda の代わりに def を利用して、子プロセスを生成する前に定義しておけば問題なく処理できるようになります。

def return_one():
    return 1

def get_primes(num: int, queue: Queue) -> None:
    primes = defaultdict(return_one)
    ...

子プロセスにデータを渡す方法

プロセスベースの並列処理で子プロセスを作成する負荷(つまり fork の負荷)はそれほど高くありません。そこで親から子に情報を受け渡す場合は親プロセス側で渡したいデータを生成した後に子プロセスを生成する方法がパフォーマンス的に優れています。
ちなみに子プロセスはリソースをコピーするものの、CoW(コピーオンライト)という仕組みのおかげで読み取りだけなら実際のメモリ領域を消費しないため、その面でもデメリットはありません。

Processargs で2GBの巨大な文字列を渡してみます。

from multiprocessing import Process, Queue

def worker(target: str, queue: Queue):
    queue.put(len(target))

# 2GB相当の文字列を生成
huge_str = "A" * 1 * 1024 * 1024 * 1024 * 2

queue = Queue()

# 2GB相当の文字列を渡して子プロセスを開始
process = Process(target=worker, args=(huge_str, queue))
process.start()
process.join()

# 実行結果の受け取り
print(queue.get())

巨大なデータですが、親プロセスが保持するメモリ空間をそのまま持つ子プロセスは実質コピーすることなくデータの受け渡しが可能なので1秒もかからず処理が終了しました。

2147483648

real    0m0.797s
user    0m0.359s
sys     0m0.439s

次はプロセスを生成した後に Queue で子プロセスに2GBのデータを送信したところ 30秒以上かかってしまいました。

from multiprocessing import Process, Queue

def worker(queue: Queue):
    queue.put(len(queue.get()))

queue = Queue()

# 子プロセスを開始
process = Process(target=worker, args=(queue,))
process.start()

# 2GB相当の文字列を生成して子プロセスに送信
huge_str = "A" * 1 * 1024 * 1024 * 1024 * 2
queue.put(huge_str)
process.join()

# 実行結果の受け取り
print(queue.get())
2147483648

real    0m30.700s
user    0m5.584s
sys     0m26.285s

親プロセス側はシリアライズして子プロセスにデータを送信し、子プロセスはそれをデシリアライズして受け取るという流れで、同じ変数に対するメモリ領域を何度も確保してコピーした結果パフォーマンスが劣化しました。

親プロセスから子プロセスへデータを渡す場合は親プロセス側でデータを作成した後に子プロセスを生成する方法が CoW を効率的に利用できて良さそうです。
ただし子プロセスから親プロセスへデータを渡す場合は同じ方法を使うことができないので、multiprocessing.Queue などを利用したプロセス間通信に頼らざるをえません。

まとめ

Python コードを並列処理したい場合はプロセスベースでやりましょう、でも注意点もありますよ。という内容でした。
スレッドを使った並列処理を実装して、「なぜか速くならないな・・・」と思っている方に届けば幸いです。

ちなみにスレッドで問題になった GIL はシステムコール実行中は解放されます。そのためストレージIOやDB操作がボトルネックならばスレッドによる並列処理でもある程度のパフォーマンス改善が可能です。
スレッドならば扱いやすいので両者の特徴をしっかり考えた上で適切な並列処理の方法を選びましょう。