[Pytorch]行列(テンソル)の積を楽にするeinsumの活用


記事を読んで何か指摘があれば、遠慮なくどうぞ。
いいねもしてもらえると励みになります。

einsumとは

  • einsumというのは、numpypytorchで実装されているアインシュタインの縮約記法というものです。(日本語でこのように呼ぶことは知らなかった。)
  • 複雑なテンソル積の演算に対して、とても意図的に操作を可能にします。

背景

  • einsumを使うまでは、pytorchでのテンソルの積の演算方法として、torch.bmmtorch.matmul(今思いついたものだけ)などを使用していました。
  • しかし、上記であげた積の演算方法には、あらかじめテンソルの次元が決められています。(例えば、3次元テンソルx2次元テンソル、2次元テンソルx1次元テンソルなど)
  • これでは、毎回ドキュメントを逐一確認して、次元に合わせた関数(メソッド)を使用しなければいけないのが大変でした。

解説

パターン1

ディープラーニングでありがちな演算

import torch as t
X = t.rand(3,10,5)
Y = t.rand(3,20,5)
  • ミニバッチが$3$
  • 行列Xのサイズが、$10\times5$
  • 行列Yのサイズが、$20\times5$

こんな時にミニバッチ内の各行列同士の積を計算して、$10\times20$の大きさの行列を計算したい。そして、バッチのまま返して欲しい場合があります。

  • つまり、XY の演算結果として、$3\times10\times20$というテンソルを返して欲しい
  • 結果として、次のように計算すれば良い
t.einsum('bnm,bkm->bnk',X,Y).size()                      
>> torch.Size([3, 10, 20])
  • ちなみに、$3\times10\times20$ではなくて、$3\times20\times10$を返して欲しい時
  • 'bnm,bkm->bkn'ここの部分の違いに注目
t.einsum('bnm,bkm->bkn',X,Y).size()                      
>> torch.Size([3, 20, 10])

パターン2

行列とベクトルの積を計算することもできます。簡単に。

X = t.rand(3,10,5)                                       
Y = t.rand(3,5)                                          
t.einsum('bnm,bm->bn',X,Y).size()                        
>> torch.Size([3, 10])

まとめ

  • einsumを使うことで複雑なテンソルの積がとても簡単に表現できます。torch.transposetorch.viewtorch.squeezeなどで次元を無理やり合致させていた人はこれでおさらばです。