[深層学習]4000倍早いTransformer, Self-Attentionの計算量がO(n^2)からO(n)になった[論文解説]


Attentionを爆速にした論文Transformers are RNNsを解説

こんにちはYosematです!
今回は長いこと計算時間が問題になっていたAttentionが爆速になってしまったという論文Transformers are RNNsを解説していきます。
今後も論文解説を続けていきますのでぜひTwitterとQiitaをフォローしてください!モチベ上がります!

忙しい人向け

  • Attentionの計算に内積を使うのをやめてカーネル関数を使う
  • Self-Attentionの計算オーダーが$O(n^2)>>O(n)$になった
  • 計算は爆速になったけどパフォーマンスはcompetetive!

Attention

Transformerでお馴染みのAttention。最初は自然言語の王様でしたが、最近は画像の認識や生成タスクでも猛威を奮っている様子で、次世代のDNNの基本的な構成要素となっていきそうですよね。

今更聴けないAttentionの仕組みについておさらいしましょう。知ってるぞー人はスキップ!

Attention機構はQuery, Kεy, Valueという3つの情報を使ってPythonでいうDictのような情報の取り出しを行うものです。

Pythonではよく

dic = {key_1: value_1, key_2: value_2, ..., key_n: value_n}

のようなdictに対してkey = queryであるような情報をとってこようとすれば

dic[query]

などと書くことになります。

Attention機構はまさにこのKey, Query、 Valueを使ったDictionaryです。

ただしQueryが完全にKeyと一致するような情報を抜き取ってくるのではなく、QueryとKeyが似ている情報を類似度に合わせて抜き取ってきます。

数式を見てみましょう。いろいろあるんですがだいたいAttentionは

Attention(Q, K, V) = sortmax(\frac{QK^T}{\sqrt{d_{key}}})V\\

Q \in R^{n_{query} \times d_{key}}, K \in R^{n_{seq}\times d_{key}}, V \in R^{n_{seq}\times d_{model}}

と書かれます。$V$は単なるデータでしかありません。Attentionで大事なのは主に$QK^T$の部分です。

ここでは$QK^T$の1行の計算$q_iK^T$に注目しましょう。

$Q$のとある行ベクトル$q$と$K$との内積にsoftmaxをとる操作は正規化された重みベクトルを作ります。その大きさをヒストグラムにしたら図の右下のような感じになります。

このヒストグラムの高さはqueryとkeyの各行とがどのくらい似ているのかを示します。(ベクトルの標準内積は同じ方向を向いたベクトルほど大きくなりますよね!)

この重みを使って$V$の各行に対して重み付和を取るのがAttention機構なのです。i行目に関して式を書くと

Attention(Q, K, V)_i = \frac{\sum_{j=1}^n\exp(q_i^Tk_j)\cdot v_j}{\sum_{j=1}^n\exp(q_i^Tk_j)}

となります。

Attention Weightが図のヒストグラムのような大きさになった場合は

Attention(Q, K, V)_i = 0.1\times v_1 + 0.5\times v_2 + 0.15\times v_3 + 0.25\times v_4

みたいな感じのベクトルが生成されますね!

この計算をQの全ての行に対して施すので出来上がるのは$n_{query}\times d_{model}$次元の行列です。

次元の話

登場人物である$Q \in R^{n_{query}\times d_{key}}$, $K\in R^{n_{seq}\times d_{key}}$, $V\in R^{n_{seq}\times d_{model}}$の次元について簡単に触れておきます。

Attentionは辞書オブジェクトにKeyとともに格納されたValueにQueryでアクセスする操作だと言いました。これを思い出せば各次元の意味もはっきり理解できるはずです。

dic = {key_1: value_1, key_2: value_2, ..., key_n: value_n}

まずQの行の数$n_{query}$はqueryを投げる数です。辞書dicに対してdic[query]のような取り出し操作を何回行うかを示しています。

Query, Keyの各行ベクトルの次元である$d_{key}$は辞書オブジェクトのキーの型を示しています。そう思えばQueryの各行ベクトルの次元も同じく$d_{key}$であることには納得がいくはずです。dic[query]として辞書の内容を取得するのにqueryとkeyの型が違ってはおかしいですからね。

Key, Valueの行の数を示す$n_{seq}$は辞書オブジェクトの要素数です。格納してあるデータの数だけkey, valueがあるはずなので$K$, $V$は行の数を共有しています。

Vの行ベクトルの次元である$d_{model}$は格納しておくべきデータの型です。辞書にどんなデータを格納するかはあなたの自由ですからd_modelはFully-Connected層のニューロンの数のように自由に決定することができるハイパーパラメータです。

だいたいどこの次元がどのような意味を持っているかを理解できたでしょうか。

Attentionは遅い

Attention機構の説明はこれくらいにして本題に入ります。
Attention機構を使ったTransformerが自然言語タスクの王者に就いて以来ずっと猛威を振るっていますがTransformerの良いところは実は計算が速いことでした。

上の図は自然言語におけるTransformerの計算コストで、Self-Attentionの計算コストは$O(n^2\cdot d)$, RNNは$O(n\cdot d^2)$。ただしnはシーケンス長、dは単語の分散表現の次元数です。

自然言語の分散表現の次元は数百、数千にもおよぶのに対して、シーケンス長はせいぜいが数十といったところでしょう。つまり通常$n << d$なのでSelf-AttentionはRNNなどよりもはるかに高速なのです。

しかし、シーケンス長$n$に対して$O(n^2)$という計算コストが残っています。自然言語ならまだしもシーエンス長が1000, 10000と大きくなってしまうようなタスクにおいてはこの計算コストは大きすぎます

Linear Attention

今回紹介する論文はAttentionの計算コストが$O(n^2)$から$O(n)$になったLinear Attentionという新しいAttention機構を提案しています。

これまでのAttention機構の数式で表すと

Attention(Q, K, V) = sortmax(\frac{QK^T}{\sqrt{d_{key}}})V

でした。

第i行に注目するなら

Attention(Q, K, V)_i = \frac{\sum_{j=1}^n\exp(q_i^Tk_j)\cdot v_j}{\sum_{j=1}^n\exp(q_i^Tk_j)}

だったはず。

ここで$QK^T$の各要素は$Q$のとある行$q_i$と$K$のとある行$k_j$との類似度です。

つまりこの計算は

Attention(Q, K, V)_i = \frac{\sum_{j=1}^nsim(q_i, k_j)\cdot v_j}{\sum_{j=1}^nsim(q_i, k_j)}

と一般化できます。$sim(q, k)=exp(\frac{q^Tk}{\sqrt{d_{key}}})$とかくと同じになりますので確認してみてください。

Attention機構の計算コストが$O(n^2)$になってしまう原因はずばりここにあります。

1つの行のAttentionを計算するのに$q_i$と$k_j$との類似度を$n_{seq}$回計算しなければいけない。
もしqueryが$n_{query}$行あるなら$sim(q_i, k_j)$を$n_{query}\cdot n_{seq}$回計算しなければいけない。

したがってQueryの数とKeyの数が$n_{query}=n_{seq}=n$個であるSelfAttention、Transformerでは計算が$n^2$回必要になってしまうわけです。

逆カーネルトリック

$O(n^2)$になってしまう根本的な問題は、それぞれn通りあるi, jについて類似度$sim(q_i, k_j)$を計算しなくてはならないことです。これを計算していては$n^2$通りの組み合わせが発生してしまうことは避けられません。

なんとか類似度を直接計算することなく

Attention(Q, K, V)_i = \frac{\sum_{j=1}^nsim(q_i, k_j)\cdot v_j}{\sum_{j=1}^nsim(q_i, k_j)}

と等価な計算をすることはできないでしょうか。

もし類似度関数が

sim(q, k) = \phi(q)^T \phi(k)

のように$q$と$k$とが分解された形で表現できるならそれが可能です。

ここで$\phi$は何らかの非線形関数です。

これはSVM(サポートベクトルマシン)でお馴染みのカーネルトリックです。

SVMでは$\phi(q)^T\phi(k)$の計算を$K(q, k)$によって代替していました。

今回はその逆で$sim(q, k)$の計算を$\phi(q)^T\phi(k)$で代替しようというアイデアです。

どんな類似度関数simについてもこうした分解を可能にする関数$\phi$があります。(類似度関数は対称で半正定値でなくてはならないですが大抵の類似度関数はこれを満たすはずなので気にしないことにします)

こうしてみるとAttentionの第$i$行の計算はこんなふうに書き換えることができます。

Attention(Q, K, V)_i = \frac{\sum_{j=1}^nsim(q_i, k_j)\cdot v_j}{\sum_{j=1}^nsim(q_i, k_j)} = \frac{\sum_{j=1}^n\phi(q_i)^T\phi(k_j)\cdot v_j}{\sum_{j=1}^n\phi(q_i)^T\phi(k_j)}

ですがよく見ると右辺の$\phi(q_i)$は$j$に関する和の計算には関係ないので

Attention(Q, K, V)_i = \frac{\phi(q_i)^T\color{red}{\sum_{j=1}^n\phi(k_j)\cdot v_j}}{\phi(q_i)^T\color{red}{\sum_{j=1}^n\phi(k_j)}}

と書いてもいいはずです。

さて、よく見てください。赤字部分の計算には$i$が含まれていないじゃないですか!

つまり何行目のAttentionを計算するのにも赤字の部分の値は変わらないわけです。つまり事前に赤字部分を計算しておくことができます。(事前計算のコストは$O(n)$)
事前に$\sum_{j=1}^n\phi(k_j)\cdot v_j = A$,$\sum_{j=1}^n\phi(k_j) = B$と求まっていれば

Attention(Q, K, V)_i = \frac{\phi(q_i)^TA}{\phi(q_i)^TB}

の計算はなんと$O(1)$です。

Queryが$n$個、つまりQが$n$行からなる行列だったとしても$Attention(Q, K, V)$の計算はO(n)なのです!!

これがLinear Attentionの根本的なアイデアです。

本家TransformerのAttention計算は高速化できない

残念なことに本家Attentionで用いられている類似度関数について

sim(q, k)=sim(q, k)=exp(\frac{q^Tk}{\sqrt{d_{key}}})=\phi(q)^T\phi(k)

を実現する$\phi$は存在するにはするんですが無限大の次元を持ちます

そんな計算はできないのでsoftmaxを用いたAttention機構を$O(n)$で計算することはできません。

でも別にsoftmaxにこだわることなんてないじゃないですか!似ていることの定義なんてcos類似度を使ったりなんてら距離関数を使ったりとなんでも使えば良いのです!

論文では

\phi(x) = elu(x) + 1

を提案しています。

まさかの類似度関数$sim$ではなく分解された表現$\phi$を先に決めてしまうという発想。

これで$O(n)$のAttention機構が出来上がりました。

パフォーマンス

めちゃめちゃ速いです。

ってのを示すのが両対数グラフってのも微妙だと思うんですけど...。
両対数グラフだと$O(n^2)$でも$O(n)$でも線形なグラフになってしまいますが、その傾きからLinearAttentionはSoftmaxを使ったAttentionよりも計算量が緩やかにスケールしていくことがわかります。

計算コストが速いのはわかった。気になる学習の収束スピードはどうなの?ということでシーケンスを複製するタスクにおける収束のパフォーマンスがこちら。

本家SoftmaxAttentionには及ばないものの匹敵する性能を持っていることがわかります。横軸はGradientStepです。つまりLinearAttentionに使った計算時間は本家Attentionより遥かに短いはずです。

画像生成タスク。Bits/dim(小さい方が良い評価指標)で同等以上の成績を残しながらも4000倍早い生成速度を誇るLinearAttention。

画像ともなると系列の長さは半端ないことになるのでO(n^2)とO(n)で比べるとこれだけの差が開くのも納得です。

まとめ

いかがだったでしょうか。

まとめると
1. 逆カーネルトリックによって類似度計算は2つの値の積に分解できる
2. 分解してしまえば1queryあたりのAttentionの計算を$O(1)$に抑えることができる

という具合でしょうか。

他にも今回紹介した論文Transformers are RNNsではTransformerでRNNを表現できるという興味深い考察も展開されています。記事が長くなるのでここでは触れませんが興味があれば読んでみるととても面白いと思います。

最近Attentionを線形時間で終わらせようという論文はいくつか登場してきています。GoogleResearchが出したBigBirdなんかも注目です。

需要があればこの辺りも解説記事を出していきたいと思っているのでぜひLGTM・フォローしてコメントもしてYosematと仲良くしてください。

それでは今日はこの辺で!お疲れ様でした!