ゼロから作るDeep Learning② word2vecコードで謎だった場所


謎だった場所

なぜself.paramsにWを代入する際にW(numpy配列)をリストにしてから、わざわざ「W,=」とまたリストの中のself.paramの要素(numpy配列)を取り出すという一見無駄なことをしているのか?

class Embedding:
    def __init__(self, W):
        self.params = [W]
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W, = self.params
        self.idx = idx
        out = W[idx]
        return out

【参考】
関数名にカンマを使用する意味についての内容が「W,=」の箇所の理解にとても役立ちました(ありがとうございます!)

その理由は、実は以下のように最初から最後までWをnumpy配列で処理をしても「Embeddingの処理だけなら」動作するが、その後の処理でリストの足し算などが行われるところでエラーとなり全体的なプログラムとしては動作しなくなるためだと理解しました。

# 以下でも「Embeddingの処理だけなら」動作する
class Embedding:
    def __init__(self, W):
        self.params = W
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W = self.params
        self.idx = idx
        out = W[idx]
        return out

つまり、以下のコードでも

import numpy as np

class Embedding:
    def __init__(self, W):
        self.params = [W]   #オリジナル
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W, = self.params   #オリジナル
        self.idx = idx
        out = W[idx]
        return out


W_in = 0.01 * np.random.randn(7, 3).astype('f')
print(f"W_in =\n{W_in}\n")

layer = Embedding(W_in)
print(f"3番目をEmbeddingしたもの:\n{layer.forward(3)}\n")

params0 = W_in
print(f"W_inをリストにしない場合:\n{params0}\n")
print(f"{type(params0)}\n")

params = [W_in]
print(f"W_inをリストにした場合:\n{params}\n")
print(f"{type(params)}\n")

W = params
print(f"カンマつけた渡し方をしない場合(W = params)のW:\n{W}")
print(type(W))
print("\n")

W, = params
print(f"カンマつけた渡し方をした場合(W, = params)のW:\n{W}")
print(type(W))

以下のコードでも

import numpy as np

class Embedding:
    def __init__(self, W):
        self.params = W   #リスト化しない
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W = self.params   #numpy配列のまま処理
        self.idx = idx
        out = W[idx]
        return out


W_in = 0.01 * np.random.randn(7, 3).astype('f')
print(f"W_in =\n{W_in}\n")

layer = Embedding(W_in)
print(f"3番目をEmbeddingしたもの:\n{layer.forward(3)}\n")

params0 = W_in
print(f"W_inをリストにしない場合:\n{params0}\n")
print(f"{type(params0)}\n")

params = [W_in]
print(f"W_inをリストにした場合:\n{params}\n")
print(f"{type(params)}\n")

W = params
print(f"カンマつけた渡し方をしない場合(W = params)のW:\n{W}")
print(type(W))
print("\n")

W, = params
print(f"カンマつけた渡し方をした場合(W, = params)のW:\n{W}")
print(type(W))

結果は同じく以下のようになる。

W_in =
[[-0.00414995 -0.00505246 -0.02271379]
 [-0.00385737 -0.01022162 -0.00621947]
 [-0.01317972  0.00763595  0.00437246]
 [ 0.01119065 -0.01144209  0.02539131]
 [-0.00316145 -0.01609291 -0.00868459]
 [ 0.00361989  0.01507116 -0.00318975]
 [ 0.00530743  0.00881439 -0.01096747]]

3番目をEmbeddingしたもの
[ 0.01119065 -0.01144209  0.02539131]

W_inをリストにしない場合:
[[-0.00414995 -0.00505246 -0.02271379]
 [-0.00385737 -0.01022162 -0.00621947]
 [-0.01317972  0.00763595  0.00437246]
 [ 0.01119065 -0.01144209  0.02539131]
 [-0.00316145 -0.01609291 -0.00868459]
 [ 0.00361989  0.01507116 -0.00318975]
 [ 0.00530743  0.00881439 -0.01096747]]
<class 'numpy.ndarray'>


W_inをリストにした場合:
[array([[-0.00414995, -0.00505246, -0.02271379],
       [-0.00385737, -0.01022162, -0.00621947],
       [-0.01317972,  0.00763595,  0.00437246],
       [ 0.01119065, -0.01144209,  0.02539131],
       [-0.00316145, -0.01609291, -0.00868459],
       [ 0.00361989,  0.01507116, -0.00318975],
       [ 0.00530743,  0.00881439, -0.01096747]], dtype=float32)]
<class 'list'>


カンマつけた渡し方をしない場合(W = params)のW:
[array([[-0.00414995, -0.00505246, -0.02271379],
       [-0.00385737, -0.01022162, -0.00621947],
       [-0.01317972,  0.00763595,  0.00437246],
       [ 0.01119065, -0.01144209,  0.02539131],
       [-0.00316145, -0.01609291, -0.00868459],
       [ 0.00361989,  0.01507116, -0.00318975],
       [ 0.00530743,  0.00881439, -0.01096747]], dtype=float32)]
<class 'list'>


カンマつけた渡し方をした場合(W, = params)のW:
[[-0.00414995 -0.00505246 -0.02271379]
 [-0.00385737 -0.01022162 -0.00621947]
 [-0.01317972  0.00763595  0.00437246]
 [ 0.01119065 -0.01144209  0.02539131]
 [-0.00316145 -0.01609291 -0.00868459]
 [ 0.00361989  0.01507116 -0.00318975]
 [ 0.00530743  0.00881439 -0.01096747]]
<class 'numpy.ndarray'>

しかし後者だとtrain.pyなどを動かしているとエラーが出て止まってしまう。それはおそらく処理の中で複数要素を一つのリストに放り込んでいく処理のところで実施されるリストの足し算が出来ない(numpy配列とリストの足し算が出来ない)ためである。
イメージ的には、以下でprint(a+c)が実行できないのと同じ。

a = np.array([[1,2,3,4,5],[6,7,8,9,0]])
b = np.arange(10).reshape(2,5)

print(a)
print(b)

print(a+b) #numpy配列の要素の数が合っているので計算できる

lista = [a]
listb = [b]

print(lista+listb)

c = np.arange(30).reshape(6,5)

print(a)
print(c)

# print(a+c) #numpy配列の要素の数が合っていないので計算できない

lista = [a]
listc = [c]

print(lista+listc) #リスト同士にするとリストの0番目にa、1番目にbが入るので、足し算が出来る

一応出力を書いておくとこんな感じ

# print(a)
[[1 2 3 4 5]   
 [6 7 8 9 0]]

# print(b)
[[0 1 2 3 4]   #print(b)
 [5 6 7 8 9]]

# print(a+b) 
[[ 1  3  5  7  9]   
 [11 13 15 17  9]]

# print(lista+listb)
 [array([[1, 2, 3, 4, 5],
       [6, 7, 8, 9, 0]]), array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])]

# print(a)
[[1 2 3 4 5]
 [6 7 8 9 0]]

# print(c)
 [[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]
 [25 26 27 28 29]]

# print(lista+listc)
 [array([[1, 2, 3, 4, 5],
       [6, 7, 8, 9, 0]]), array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]])]