pythonピット:ループ内のlambda変数ドメインの問題

12954 ワード

最近廖雪峰先生のpythonチュートリアルを振り返ると、「高次関数」というセクションにfilter関数とジェネレータを利用して素数を求めるコードがあり、ここでは理解とデバッグを容易にするために簡略化されています.
def func(n):
    return lambda x: x % n > 0

def primes():
    it = (i for i in np.arange(3, 20, 2))
    while True:
        n = next(it)
        yield n
        it = filter(func(n), it)
        
print(list(primes()))
:[3, 5, 7, 11, 13, 17, 19]
コードは簡単に見えます.func()の機能は匿名関数を返し、入力値がnより大きいかどうかを判断することです.prime()は素数を返すジェネレータであり、実装プロセスは偶数を含まないジェネレータを定義し、次に素数を返すたびに、この素数によってジェネレータから除去できる要素をfilterで排除する.ここを見たとき、なぜ単独でfunc()を定義するのか、どうせ戻ってくるのもlambdaなので、いっそfunc()の内容をfilterに入れてpythonを書くのはもちろんコードを簡潔にすればするほどいいのではないかと思っていました.しかし、悪夢が始まった.
def primes():
    it = (i for i in np.arange(3, 20, 2))
    while True:
        n = next(it)
        yield n
        it = filter(lambda x: x % n > 0, it)
        
print(list(primes()))
:[3, 5, 7, 9, 11, 13, 15, 17, 19]
しばらく心が冷めて、穴を踏んだ以上、穴をゆっくり埋めましょう.まずfilter()の使い方を思い出します.関数とシーケンスを受信し、各要素に順次作用し、返されたブール値に基づいて要素を保持するかどうかを決定します.フィルタを使用して偶数のフィルタを実装してみます.
L = list(filter(lambda s: s % 2 == 0, (i for i in range(5))))
print(L)
:[0, 2, 4]
その結果,filter()の論理は簡潔であり,匿名関数とジェネレータはfilter()関数パラメータとしても問題ない.しかし、これまでの結果を振り返ると、yield 3以降filterは3で除去できる元素を排除する役割を果たしておらず、9と15は依然として最終結果に現れているため、lambdaでは問題があったと推測される.簡単なlambdaに変えてもう一度試してみましょう.
def primes():
    it = (i for i in np.arange(3, 20, 2))
    while True:
        n = next(it)
        yield n
        it = filter(lambda x: x < 15, it)
        
print(list(primes()))
:[3, 5, 7, 9, 11, 13]
結果は正しい.なぜこのような問題が発生したのか、私は多くの資料を探して発見しましたが、実は公式のFAQでこのようなpythonの穴を紹介しています.公式の例を見てみましょう.
squares = []
for x in range(5):
    squares.append(lambda: x**2)
    
print([i() for i in squares])
:[16, 16, 16, 16, 16]
前の例と似ているような気がしますが、結果は同じ不思議ですか?0から4の平方を伝えているのに、なぜ出てきたのは4の平方ばかり?公式の解釈は以下の通りです.
This happens because x is not local to the lambdas, but is defined in the outer scope, and it is accessed when the lambda is called — not when it is defined. At the end of the loop, the value of x is 4, so all the functions now 16.
公式には、xの値を変更し続け、lambdaの結果がどのように変化しているかを確認することで、これを検証することもできます.
x = 8
squares[2]()
:64
このような状況を回避するために、グローバル変数(元のコードのfunc(n)関数がこの役割を果たす)の代わりに、外部数値をパラメータとしてlambdaに渡す必要があります.コードは以下の通りです.
squares = []
for x in range(5):
    squares.append(lambda n=x: n**2)

print([i() for i in squares])
:[0, 1, 4, 9, 16]
同様に、以前のlambdaも次のように変更できます.
def primes():
    it = (i for i in np.arange(3, 20, 2))
    while True:
        n = next(it)
        yield n
        it = filter(lambda x, n=n: x % n > 0, it)
        
print(list(primes()))
:[3, 5, 7, 11, 13, 17, 19]