[Pytorch]行列(テンソル)の積を楽にするeinsumの活用
記事を読んで何か指摘があれば、遠慮なくどうぞ。
いいね
もしてもらえると励みになります。
einsumとは
-
einsum
というのは、numpy
やpytorch
で実装されているアインシュタインの縮約記法
というものです。(日本語でこのように呼ぶことは知らなかった。)
- 複雑なテンソル積の演算に対して、とても意図的に操作を可能にします。
背景
-
einsum
を使うまでは、pytorch
でのテンソルの積の演算方法として、torch.bmm
やtorch.matmul
(今思いついたものだけ)などを使用していました。
- しかし、上記であげた積の演算方法には、あらかじめテンソルの次元が決められています。(例えば、3次元テンソルx2次元テンソル、2次元テンソルx1次元テンソルなど)
- これでは、毎回ドキュメントを逐一確認して、次元に合わせた関数(メソッド)を使用しなければいけないのが大変でした。
解説
パターン1
einsum
というのは、numpy
やpytorch
で実装されているアインシュタインの縮約記法
というものです。(日本語でこのように呼ぶことは知らなかった。)einsum
を使うまでは、pytorch
でのテンソルの積の演算方法として、torch.bmm
やtorch.matmul
(今思いついたものだけ)などを使用していました。ディープラーニングでありがちな演算
import torch as t
X = t.rand(3,10,5)
Y = t.rand(3,20,5)
- ミニバッチが$3$
- 行列Xのサイズが、$10\times5$
- 行列Yのサイズが、$20\times5$
こんな時にミニバッチ内の各行列同士の積を計算して、$10\times20$の大きさの行列を計算したい。そして、バッチのまま返して欲しい場合があります。
- つまり、XY の演算結果として、$3\times10\times20$というテンソルを返して欲しい
- 結果として、次のように計算すれば良い
t.einsum('bnm,bkm->bnk',X,Y).size()
>> torch.Size([3, 10, 20])
- ちなみに、$3\times10\times20$ではなくて、$3\times20\times10$を返して欲しい時
- 'bnm,bkm->bkn'ここの部分の違いに注目
t.einsum('bnm,bkm->bkn',X,Y).size()
>> torch.Size([3, 20, 10])
パターン2
行列とベクトルの積を計算することもできます。簡単に。
X = t.rand(3,10,5)
Y = t.rand(3,5)
t.einsum('bnm,bm->bn',X,Y).size()
>> torch.Size([3, 10])
まとめ
-
einsum
を使うことで複雑なテンソルの積がとても簡単に表現できます。torch.transpose
やtorch.view
、torch.squeeze
などで次元を無理やり合致させていた人はこれでおさらばです。
Author And Source
この問題について([Pytorch]行列(テンソル)の積を楽にするeinsumの活用), 我々は、より多くの情報をここで見つけました https://qiita.com/marusta/items/e0dbef7b6c900cb3b785著者帰属:元の著者の情報は、元のURLに含まれています。著作権は原作者に属する。
Content is automatically searched and collected through network algorithms . If there is a violation . Please contact us . We will adjust (correct author information ,or delete content ) as soon as possible .