Pythonのジェネレータってyieldするだけじゃなかったんだね

28084 ワード

こんにちわ alivelimb です。
Pythonista のみなさん、TypeHint(型ヒント)書いてますか?私は基本的に書くようにしています。今回は TypeHint で ジェネレータ(Generator) を書く時に出てきた疑問とそれついて調べたことを紹介します

もし使い捨てコードでなく、Python3.5 以上を利用されているにも関わらず TypeHint を書いていない方がいれば、是非書くことを検討してください。TypeHint をつけることで、コード理解がしやすくなり、IDE の予測変換も活用できるので開発効率向上が期待できます。TypeHint の恩恵を最大限受ける VSCode の環境構築については記事にしているので、参照してみてください。

VSCode と Poetry で作る Python 開発環境

そもそも ジェネレータ とは

公式ドキュメントには以下の記述があります。

generator iterator を返す関数です。 通常の関数に似ていますが、 yield 式を持つ点で異なります。 yield 式は、 for ループで使用できたり、next() 関数で値を 1 つずつ取り出したりできる、値の並びを生成するのに使用されます。

なるほど。generator iteratorについては

yield のたびに局所実行状態 (局所変数や未処理の try 文などを含む) を記憶して、処理は一時的に中断されます。 ジェネレータイテレータ が再開されると、中断した位置を取得します (通常の関数が実行のたびに新しい状態から開始するのと対照的です)。

という記載があります。これは何が嬉しいのでしょうか。StackOverflowでは「遅延評価ができるようになる」との回答ありました。同回答で紹介されている検索システムの具体例を交えて紹介してみます。

ユーザテーブルを SQLite で作成し、指定した年齢以下のユーザの ID 一覧を検索する例で考えてみましょう。まずはユーザテーブルを作成します。

# ユーザテーブル作成
def init_users_table(n: int) -> None:
    # n: ダミーユーザ数
    conn = sqlite3.connect("sample_users.db")
    cursor = conn.cursor()

    # テーブル作成
    cursor.execute("CREATE TABLE users(user_id, age)")

    # ダミーユーザ追加
    for _ in range(n):
        user_id = str(uuid4())
        age = randint(0, 100)
        cursor.execute("INSERT INTO users VALUES (?, ?)", (user_id, age))

    conn.commit()
    conn.close()

このテーブルで検索を行いますが、検索結果をリストで受け取るようにすると以下のようになるでしょうか。

def get_user_ids(age: int) -> List[str]:
    conn = sqlite3.connect("sample_users.db")
    cursor = conn.cursor()

    cursor.execute("SELECT user_id FROM users WHERE age <= ?", (age,))

    return cursor.fetchall()

リストの場合、条件に合うレコードを全てメモリに載せることになります。今回の例ではダミーユーザ数が相当大きくならないとメモリエラーにはなりません。しかし、取得するカラム数が増えていけばメモリは圧迫されてしまいます。

こんな時にジェネレータが役に立ちます。

def get_user_ids_generator(age: int) -> Generator[str, None, None]:
    conn = sqlite3.connect("sample_users.db")
    cursor = conn.cursor()

    cursor.execute("SELECT user_id FROM users WHERE age <= ?", (age,))

    for user_id in cursor.fetchall():
        yield user_id

TypeHint の通り、リストではなくジェネレータイテレータを返します。ジェネレータイテレータは__next__メソッドが呼ばれるとジェネレータのyieldを 1 回実行した後、次のyieldで処理を中断するので、条件に一致する全てのユーザ ID ではなく、1 つずつ取り出すことが可能です。(for 文は内部的に__next__を呼んでいるようです)。

ジェネレータやイテレータの詳細については以下の記事が参考になったので、適宜参照してください。

typing.Generator

さて、本記事の本題です。公式ドキュメントには以下のような記述があります。

Generator[YieldType, SendType, ReturnType]

ここで私が思ったのは「YieldTypeyieldする値の型ってわかるけど、SendTypeって何?あとReturnTypeってあるけど、ジェネレータってreturn返せるんだっけ?」ということです。具体例として以下の記述があります。

def echo_round() -> Generator[int, float, str]:
    sent = yield 0
    while sent >= 0:
        sent = yield round(sent)
    return 'Done'

このコードを見た時の疑問をまとめてみると

  • yieldが変数に代入されてるのは何故?
  • sentが 0 以上になる時ってどういう時?
  • returnの時ってどうなるの?

ということでした。この疑問を解消するのが本記事の目的です。1 つずつ紐解いていきましょう。

SendType

まずはSendType です。今まで知らなかったですが、以下のようにジェネレータには値を渡すことができるようです。

# ジェネレータを作成
gen = make_generator()

# ジェネレータに値を渡す
gen.send(value)

SendType.sendで渡す、値の型である」

ということになりますね。公式ドキュメントで紹介されているカウンターの例で見てみます。

def counter(maximum: int) -> Generator[int, int, None]:
    i = 0
    while i < maximum:
        # sendされた値をvalとして受け取る
        val = yield i

        # sendされていれば、カウント(i)をvalにする
        if val is not None:
            i = val

        # sendされていなければ、カウント(i)を1進める
        else:
            i += 1


it = counter(10)
print(next(it))    # 0
print(next(it))    # 1
print(it.send(8))  # 8
print(next(it))    # 9

なるほど、.sendを使うことでジェネレータに値を入力し、カウントを進められていますね。「.sendの次のnextは 9 じゃなくて 8 にじゃない?」と思ったのですが、違いました。同じように思った方は以下の検証を参照してください。


def counter_next(it: Generator[int, int, None]) -> None:
    print("next")
    print("-" * 10)
    value = next(it)
    print(f"value: {value}")
    print("-" * 10)
    print()


def counter_send(it: Generator[int, int, None], v: int) -> None:
    print("send")
    print("-" * 10)
    value = it.send(v)
    print(f"value: {value}")
    print("-" * 10)
    print()


def counter(maximum: int) -> Generator[int, int, None]:
    i = 0
    while i < maximum:
        print(f"before count: {i}")
        val = yield i
        if val is not None:
            i = val
        else:
            i += 1

        print(f"after count: {i}")


it = counter(10)
counter_next(it)
counter_next(it)
counter_send(it, 8)
counter_next(it)

実行結果

next
----------
before count: 0
value: 0
----------

next
----------
after count: 1
before count: 1
value: 1
----------

send
----------
after count: 8
before count: 8
value: 8
----------

next
----------
after count: 9
before count: 9
value: 9
----------

ジェネレータはyieldで一時的に関数を抜ける形になるので、after と before が逆になってます。2 回目以降のnextは前に実行したnextの後処理から始まって最後にyieldになります。そのためnextを実行すると、先に i がインクリメントされてからyieldされるイメージです。

ReturnType

次にReturnTypeです。公式ドキュメントに以下の記載がありました。

ジェネレータ関数の中では、return value は __next__() メソッドから送出された StopIteration(value) を引き起こします。これが発生した場合や、関数の終わりに到達した場合は、値の生成が終了してジェネレーターがそれ以上の値を返さない。

これを整理すると

  • ジェネレータ内でreturnするとStopIterationという例外を吐く
  • ジェネレータ内でreturnすると関数の終わりに到達した場合と同様にジェネレータがそれ以上値を返さない
  • StopIterationには値(return value)を持たせることが出来て、この値の型がReturnType

ということになるでしょうか。実際に試してみましょう。

def counter(maximum: int) -> Generator[int, int, int]:
    i = 0
    while i < maximum:
        val = yield i
        if val is not None:
            if val >= 0:
                i = val
            else:
                return -1
        else:
            i += 1

    return 0

先ほどのカウンターの例にreturnを追加して

  • カウントが最後まで行ったら 0 を返す
  • カウントが 0 未満で.sendで書き換えられたら-1を返す

ように修正してみました。まず0を返すケースから検証します

it = counter(3)
print(next(it))  # 0
print(next(it))  # 1
print(next(it))  # 2
print(next(it))  # StopIteration: 0 (例外発生)

想定通り、カウントが最後までいくとStopIterationが発生し、値に0が入っていますね。次に-1を返すケースを検証してみます。

print(next(it))  # 0
try:
    it.send(-10)  # 例外発生
except StopIteration:
    print(traceback.format_exc()) # StopIteration: -1
print(next(it)) # StopIteration: 例外発生

こちらも想定通りStopIterationが発生し、値に0が入っています。また、1 回StopIterationが発生した後にもう一度next()で取り出そうとすると、同じくStopIterationが発生しますが、値は何も入らないようです。

まとめ

TypeHint 発端でジェネレータについて調べてみました。今まで知らなかったジェネレータの機能を知れて面白かったです。正直「ジェネレータでSendTypeReturnTypeなんて使う時あるのかな」と思っていたのですが、それは私が Python で非同期処理をほとんど書いたことがないからでしょうか。本記事では触れていないyield fromなども含めて非同期処理についても今後勉強してみたいと思います。