Celeryのchainで複数の引数をタスクに渡す


Celeryのタスクchainの引数について

  • celeryのタスクをchainさせると、1つ目のタスクの戻り値が次のタスクの引数に入る仕組みになっています。
  • したがって、チェーンするタスク同士のシグネチャはうまく整合させる必要があります。
  • ここで、複数の引数を受け入れるタスクにはどのようにして値を渡すのか疑問です。
chains = chain(task_a.s(), task_b.s())


@shared_task
def task_a():
    # 処理...
    return ???


@shared_task
def task_b(?? a,b):
    # 処理...
    return None
  • シグネチャを合わせるといわれて一番最初に思いつく方法は以下のようなものでしょうか。しかし、うまくいきません。
@shared_task
def task_a():
    # 処理...
    return a, b


@shared_task
def task_b(a, b):
    # 処理...
    return None

> TypeError("task_b() missing 1 required positional argument: 'b'"):

対策1

  • pythonで複数の戻り値をカンマ区切りでreturnすると自動的にタプルとして返され、結果をタプルのアンパッキングを使って受け取れます。
def method():
	return "a", "b"
	
a, b = method()
  • したがって、複数の引数ではなくタプルを受け取るようにして、メソッド内でアンパッキングさせればよいことがわかります。
chains = chain(task_a.s(), task_b.s())


@shared_task
def task_a():
    # 処理...
    return 'a', 'b'


@shared_task
# 注)*argsではない
# Tuple([a, b])となってしまってunpackしづらい
def task_b(args): 
    # 処理...
    a, b = args
    logger.info("args={a} , {b}")
    return None
  • この方法の注意点としては可変長引数『*args』で定義していないことです。
  • 好みの問題かもしれませんがceleryのタスクchainでは可変長引数はListとして扱われるようなのでargs[0]に対してアンパッキングすることになり、何をしているのかわかりづらいです。

対策2

  • 辞書でreturnして引数で受け取る方法も動作します。
chains = chain(task_a.s(), task_b.s())


@shared_task
def task_a():
    # 処理...
    return {'a': 'a', 'b': 'b'}


@shared_task
def task_b(args_dict): # 注)**kwargsではない
    # 処理...
    a, b = args_dict.values()
    logger.info("args={a}, {b}")
    return None

  • この方法の注意点は『**kwargs』のように可変長なキーワード引数としてtask_bを定義するとうまくいきません。

  • メソッドのチェーンで『**kwargs』と書くと、task_bは位置引数を1つも期待していないのに、task_aからreturnされたdict型の位置引数を1つ渡したことになり例外を投げます。

  • もっと良い方法や、間違いがあれば修正いたしますのでコメントお願いします。