Python(PyTorch)で自作して理解するTransformer


1. はじめに

Transformerは2017年に「Attention is all you need」という論文で発表され、自然言語処理界にブレイクスルーを巻き起こした深層学習モデルです。論文内では、英語→ドイツ語翻訳・英語→フランス語翻訳という二つの機械翻訳タスクによる性能評価が行われています。それまで最も高い精度を出すとされていたRNNベースの機械翻訳と比較して、

  • 精度(Bleuスコア)
  • 訓練にかかるコストの少なさ

という両方の面で、Transformerはそれらの性能を上回りました。以降、Transformerをベースとした様々なモデルが提案されています。その例としては、BERT,XLNet,GPT-3といった近年のSoTAとされているモデルが挙げられます。

ここで、「Attention is all you need」内に掲載されているTransformerの構造の図を見てみましょう

Transformer

ニューラル機械翻訳におけるTransformerはある時系列データを別の時系列データに変換する(Ex:日本語による文章を英語による文章に翻訳する)ようなタスクに用いられるEncoder-Decoder(seq2seq)の構造をしているという点ではRNN(LSTM,GRU)ベースのモデルと同じです。

しかし、Transformerの最大の特徴はEncoder・DecoderのいずれにもRNNのような再帰計算を必要とする層が存在せず、その代わりとしてこの後説明するAttentionが用いられている点です。

Transformerは自然言語処理のみならず、他の分野でも用いられる汎用性の高いモデルです。また2022年現在、登場から約5年という月日が経過しているため、主要なディープラーニングフレームワークであるPyTorchTensorflowのいずれにも既に公式実装が存在しており、研究等を目的に実装する際はこれらを用いるのが一般的かと思います。

しかし、作ることは理解することへの近道。ということで、今回は取り組んだのはTransformerとTransformerを構成する層のスクラッチ実装です。本記事では、Transformerモデルを構成する各レイヤの理論的背景およびPyTorchによる実装を紹介していきます。

なお、今回はPyTorchの学習も兼ねているため、PyTorchLightningIgniteのような便利ライブラリは使用せず、素のPyTorchのみで実装しています。

以下は実装したリポジトリになります。

https://github.com/YadaYuki/en_ja_translator_pytorch

2.ディレクトリ構成概観

それでは、早速実装を見ていきましょう。細部に注目する前に、プロジェクトの全体像を概観するため、ディレクトリ構成を見てみます。

。
├── const // pathなどの定数値
│   └── path.py
├── corpus // 訓練用のデータ・コーパスが入る
│   └── kftt-data-1.0
├── figure
├── layers // 深層ニューラルネットを構成するレイヤの実装
│   └── transformer
│       ├── Embedding.py
│       ├── FFN.py
│       ├── MultiHeadAttention.py
│       ├── PositionalEncoding.py
│       ├── ScaledDotProductAttention.py
│       ├── TransformerDecoder.py
│       └── TransformerEncoder.py
├── models // 深層ニューラルネットモデルの実装
│   ├── Transformer.py
│   └── __init__.py
├── mypy.ini
├── pickles // モデルやデータセットのpickleファイルを格納
│   └── nn/
├── poetry.lock
├── poetry.toml
├── pyproject.toml
├── tests // テスト(pytest)
│   ├── conftest.py
│   ├── layers/
│   ├── models/
│   └── utils/
├── train.py // 訓練用コード
└── utils // DatasetやVocabといったクラスの実装,前処理に用いる関数の実装
    ├── dataset/
    ├── download.py
    ├── evaluation/
    └── text/

poetry.*という名前のファイルが存在することからわかる通り、ライブラリのパッケージ管理にはPoetryを採用しました。Poetryには、

  • poetry installコマンド一つで必要なライブラリを全てインストールし、仮想環境を作成することができる。
  • 開発時のみ必要なライブラリと開発時と本番環境の両方で必要なライブラリを一つのファイル(pyproject.toml)で定義できる。

などのメリットがあります。

フォルダ構成としては、models以下に「Transformer本体」、layers以下に「Transformerを構成するレイヤー」という次章で説明する実装に該当するファイルが存在します。

3.Transformerの実装

それではプロジェクトの全体像が掴めたところで、Transformerを構成する各層およびTransformer本体の実装を見ていきましょう。

3.1 Attention

先ほども述べた通り、Transformerの最大の特徴はRNNのように訓練時の再帰計算を行わずに、Attention層を用いている点です。「Attention is all you need」という論文のタイトルが物語っているように、AttentionはTransformerの中で最も重要な層です。ここでは、Transformerの核となる層であるAttention層について説明します。

Attention層の目的は、入力ベクトルの各要素の重要度を算出し、それによって入力ベクトルを重み付けること。Attention層では、画像やテキストといった入力データのベクトルの中で、正しい出力を得るために重要な要素はどれであるか?という重要度の役割を果たすAttention weightという値を算出します

「入力ベクトル」と「算出したAttention weight」の積を計算することにより、入力ベクトルの中で正解ラベルを得る上で特に重要な要素を重み付けることが可能です。

Attentionを定式化してみましょう。まずは、重要度を表すAttention weightの計算です。Attention weightを求める関数\alphaを用いて入力XとQ(クエリ)から、 Xの重要度を求めます。

Attention\,weight = \alpha(Q,X)

上の式で求めたAttention weight(Xの重要度)で入力Xを重み付けたものがAttentionの出力になります。

Output = (Attention\,weight)X =\alpha(Q,X)X

ここでは、「Attention weightを算出し、それにより入力データが重み付けられる」という演算の流れがわかるように、上のような式でAttention層での演算を表しました。しかし、本家の論文内ではAttentionの演算を、Q(クエリ),K(キー),V(バリュー)という3つの記号を用いて、

Attention(Q,K,V)

のように表しています。先程の\alpha(Q,X)XAttention(Q,K,V)はいずれもAttentionの演算であるため、書き方が異なるだけで、やっていることは全く同じです。

\alpha(Q,X)XをQ,K,Vによって表すと以下の通りになります。

Attention(Q,K,V) = \alpha(Q,K)V

\alpha(Q,X)Xは、すなわち、Attention(Q,K,V)におけるK,Vの両方がXであるということです。

\alpha(Q,X)X = Attention(Q,X,X)

ここで、\alpha(Q,X)XにおけるQとXが同じであるか、そうでないかによってAttentionを区別できます。Q(クエリ)にも入力Xを用いるもの、すなわち、\alpha(X,X)Xであるものをself-attention、一方で、Q,Xに異なるデータを用いるものをsource-to-target attentionと呼びます

機械翻訳におけるTransformerでは

  • Xを翻訳元の文章データとしたself attention ( in Encoder )
  • Xを翻訳対象の文章データとしたself attention ( in Decoder )
  • Qを翻訳対象の文章データ,X翻訳元の文章データをsource-to-target attention ( in Decoder )

という3通りのAttentionの使われ方が存在します。

以上がAttention層の概要に関する説明です。

さて、本節では、Attention層が行う演算の概要とその目的について説明しました。しかし肝心なのは、入力の重要度を表すAttention weightをどのように算出すべきであるか?という点です(関数\alphaが何であるか?)

TransformerではクエリQ,入力データXの内積を計算することによって、Attention weightを算出するScaledDotProductAttentionという手法が採用されています。次の節で、その詳細を見ていきましょう。

3.2 ScaledDotProductAttention

Transformerで用いられるAttentionであるScaledDotProductAttentionのAttention weight計算は、クエリQ(N{\times}D)と入力X(N{\times}D)を用いて、以下のような式で表すことができます。

\alpha(Q,X) = softmax(\frac{QX^T}{\sqrt{D}})

質問応答や機械翻訳といったタスクにTransformerモデルで取り組んでいる場合、上の式におけるQ,Xはそれぞれ文章データを行列で表現したものです。扱うデータが「N個の単語を持つ文章を、D次元の単語分散表現で表現したデータ」であった場合、N{\times}Dというサイズを持つ行列となっています。

そのため、この演算では、クエリQ内の各単語の分散表現と入力データX内の各単語の分散表現の内積を算出している、ということです。ベクトル同士の内積が大きいということは、向いている方向が近い、すなわち、ベクトル同士の類似度が高いということです(単語同士の類似度が高い)。つまり、ScaledDotProductAttentionに文章データを入力した場合、Q,X内の単語同士の類似度を、入力に対する重要性として重み付けていると解釈することが可能です。

上で求めたAttention weightと入力Xの積を求めることで最終的な出力が得られます。従って、入力X,クエリQに対するScaledDotProductAttentionの出力は以下の式で表すことができます。

Attention(Q,K,V) = Attention(Q,X,X) = \alpha(Q,X)X = softmax(\frac{QX^T}{\sqrt{D}})X

それでは、 ScaledDotProductAttentionのPyTorchによる実装を見てみましょう。

import numpy as np
import torch
from torch import nn


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k: int) -> None:
        super().__init__()
        self.d_k = d_k

    def forward(
        self,
        q: torch.Tensor,  # =Q
        k: torch.Tensor,  # =X
        v: torch.Tensor,  # =X
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        scalar = np.sqrt(self.d_k)
        attention_weight = torch.matmul(q, torch.transpose(k, 1, 2)) / scalar # 「Q * X^T / (D^0.5)」" を計算

        if mask is not None: # maskに対する処理
            if mask.dim() != attention_weight.dim():
                raise ValueError(
                    "mask.dim != attention_weight.dim, mask.dim={}, attention_weight.dim={}".format(
                        mask.dim(), attention_weight.dim()
                    )
                )
            attention_weight = attention_weight.data.masked_fill_(
                mask, -torch.finfo(torch.float).max
            ) 

        attention_weight = nn.functional.softmax(attention_weight, dim=2) # Attention weightを計算
        return torch.matmul(attention_weight, v) # (Attention weight) * X により重み付け.

3.3 Multihead Attention

3.2節で、TransformerモデルはAttentionの計算方法としてScaledDotProductAttentionを採用していることを説明しました。

しかし、Transformerで採用されているAttentionは単なるScaledDotProductAttentionではありません。実際のTransformerでは単一の入力に対して、複数のScaledDotProductAttentionを並列で実行するMultihead Attentionという仕組みが採用されています。

ここで、本家の論文に掲載されているMultihead Attentionの概要図を見てみましょう。

Multihead Attention

図中のh(ヘッド数)は並列実行するScaledDotProductAttentionの数を表します。

Multihead Attentionでは、以下のような処理が行われます。

  1. Attention層に対する入力Q(N_{Q}\timesd_{model}),K(N\timesd_{model}),V(N\timesd_{model})をh(ヘッド数)の数だけ複製する
  2. 複製した入力Q_i,K_i,V_i(i=1~h)を行列W_i^q(d_{model}\timesd_k),W_i^k(d_{model}\timesd_k),W_i^v(d_{model}\timesd_v)により、d_{model}d_v,d_kへと線形変換する。
  3. そうして得られたQ_iW_i^q(N_Q\timesd_k),K_iW^k_i(N\timesd_k),V_iW^v_i(N\timesd_v)をh個存在するScaledDotProductAttentionへ入力する
  4. 並列実行されたScaledDotProductAttentionから得られるh個の出力head(i=1~h,N\timesd_v)を結合(concat)し、行列O(N\timeshd_v)を得る。
  5. OW^OによりOhd_vd_{model}に線形変換し、得られた値が最終的な出力となる。

定式化すると以下の通りになります。

head_i = ScaledDotProductAttention(Q_iW^q_i,K_iW^k_i,V_iW^v_i) (i = 1 \sim h)
O = Concat(head_1, ... , head_h)
MultiHead(Q,K,V) = OW^O

それでは、以上を踏まえて、実装を見ていきましょう。

import torch
from layers.transformer.ScaledDotProductAttention import ScaledDotProductAttention
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, h: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        self.d_k = d_model // h
        self.d_v = d_model // h

        #
        self.W_k = nn.Parameter(
            torch.Tensor(h, d_model, self.d_k)  # ヘッド数, 入力次元, 出力次元(=入力次元/ヘッド数)
        )

        self.W_q = nn.Parameter(
            torch.Tensor(h, d_model, self.d_k)  # ヘッド数, 入力次元, 出力次元(=入力次元/ヘッド数)
        )

        self.W_v = nn.Parameter(
            torch.Tensor(h, d_model, self.d_v)  # ヘッド数, 入力次元, 出力次元(=入力次元/ヘッド数)
        )

        self.scaled_dot_product_attention = ScaledDotProductAttention(self.d_k)

        self.linear = nn.Linear(h * self.d_v, d_model)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask_3d: torch.Tensor = None,
    ) -> torch.Tensor:

        batch_size, seq_len = q.size(0), q.size(1)

        """repeat Query,Key,Value by num of heads"""
        q = q.repeat(self.h, 1, 1, 1)  # head, batch_size, seq_len, d_model
        k = k.repeat(self.h, 1, 1, 1)  # head, batch_size, seq_len, d_model
        v = v.repeat(self.h, 1, 1, 1)  # head, batch_size, seq_len, d_model

        """Linear before scaled dot product attention"""
        q = torch.einsum(
            "hijk,hkl->hijl", (q, self.W_q)
        )  # head, batch_size, d_k, seq_len
        k = torch.einsum(
            "hijk,hkl->hijl", (k, self.W_k)
        )  # head, batch_size, d_k, seq_len
        v = torch.einsum(
            "hijk,hkl->hijl", (v, self.W_v)
        )  # head, batch_size, d_k, seq_len

        """Split heads"""
        q = q.view(self.h * batch_size, seq_len, self.d_k)
        k = k.view(self.h * batch_size, seq_len, self.d_k)
        v = v.view(self.h * batch_size, seq_len, self.d_v)

        if mask_3d is not None:
            mask_3d = mask_3d.repeat(self.h, 1, 1)

        """Scaled dot product attention"""
        attention_output = self.scaled_dot_product_attention(
            q, k, v, mask_3d
        )  # (head*batch_size, seq_len, d_model)

        attention_output = torch.chunk(attention_output, self.h, dim=0)
        attention_output = torch.cat(attention_output, dim=2)

        """Linear after scaled dot product attention"""
        output = self.linear(attention_output)
        return output

3.4 PositionalEncoding

Postional Encoding層では、系列データ内の各要素に、要素のデータ内における位置情報を付与する役割を担います。例えば、文章の場合、Positiona Encodingによって、各単語ベクトルに、その単語が文章内で何番目に登場するか?という情報を付与することが可能です。

Transformerでは、以下の式で算出した値を固定値として、入力に加算します。

PE_{(pos,2i)} = sin(pos/10000^{2i/d_{model}})
PE_{(pos,2i+1)} = cos(pos/10000^{2i/d_{model}})

上の式で「なぜ位置情報を正しく付与することができるのか?」という点はここでは詳しく説明しませんが、それについては、こちらが参考になると思います。それでは、実装を見てみましょう。

import numpy as np
import torch
from torch import nn


class AddPositionalEncoding(nn.Module):
    def __init__(
        self, d_model: int, max_len: int, device: torch.device = torch.device("cpu")
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        positional_encoding_weight: torch.Tensor = self._initialize_weight().to(device)
        self.register_buffer("positional_encoding_weight", positional_encoding_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        return x + self.positional_encoding_weight[:seq_len, :].unsqueeze(0)

    def _get_positional_encoding(self, pos: int, i: int) -> float:
        w = pos / (10000 ** (((2 * i) // 2) / self.d_model))
        if i % 2 == 0:
            return np.sin(w)
        else:
            return np.cos(w)

    def _initialize_weight(self) -> torch.Tensor:
        positional_encoding_weight = [
            [self._get_positional_encoding(pos, i) for i in range(1, self.d_model + 1)]
            for pos in range(1, self.max_len + 1)
        ]
        return torch.tensor(positional_encoding_weight).float()

3.5 Position-wise Feed-Forward Networks

Position-wise Feed-Forward Networks(FFN)は2つの全結合層(Linear)を重ねただけの非常に単純な層です。1つ目の層からの出力に対する活性化関数にはReluを用いています。Position-wise Feed-Forward Networks(FFN)を定式化すると以下の通りです。

FFN(x) = max(0, xW_1 + b_1)W_2 + b_2

それでは、実装を見てみましょう。

import torch
from torch import nn
from torch.nn.functional import relu


class FFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(relu(self.linear1(x)))

3.6 Encoder

Transformerによる機械翻訳は、seq2seqの一種であるため、RNNベースのモデルと同様、EncoderDecoderの構造をしている、というのは既に説明しました。ここで紹介するのは、TransformerにおけるEncoder、すなわち、Transformerの概要図における以下の部分です。

TransformerのEncoder

ということで、ここでは、今までの紹介した層を組み合わせて、TransformerにおけるEncoderを実装します。

TransformerEncoderは

  • Embedding層(単語id列による文章データを単語分散表現による文章データに変換する)
  • Positiona Encoding層
  • Multihead Attention, FeedForward Networkの2つの層(それぞれに対して、Add & Normを施す )から構成されるTransformerEncoderBlock層を任意のN層重ね合わせた層

という大きく分けて3つのパーツで構成されます。それでは実装を見てみましょう。

Encoder
import torch
from torch import nn
from torch.nn import LayerNorm

from .Embedding import Embedding
from .FFN import FFN
from .MultiHeadAttention import MultiHeadAttention
from .PositionalEncoding import AddPositionalEncoding


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        heads_num: int,
        dropout_rate: float,
        layer_norm_eps: float,
    ) -> None:
        super().__init__()

        self.multi_head_attention = MultiHeadAttention(d_model, heads_num)
        self.dropout_self_attention = nn.Dropout(dropout_rate)
        self.layer_norm_self_attention = LayerNorm(d_model, eps=layer_norm_eps)

        self.ffn = FFN(d_model, d_ff)
        self.dropout_ffn = nn.Dropout(dropout_rate)
        self.layer_norm_ffn = LayerNorm(d_model, eps=layer_norm_eps)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        x = self.layer_norm_self_attention(self.__self_attention_block(x, mask) + x)
        x = self.layer_norm_ffn(self.__feed_forward_block(x) + x)
        return x

    def __self_attention_block(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """
        self attention block
        """
        x = self.multi_head_attention(x, x, x, mask)
        return self.dropout_self_attention(x)

    def __feed_forward_block(self, x: torch.Tensor) -> torch.Tensor:
        """
        feed forward block
        """
        return self.dropout_ffn(self.ffn(x))


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        pad_idx: int,
        d_model: int,
        N: int,
        d_ff: int,
        heads_num: int,
        dropout_rate: float,
        layer_norm_eps: float,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model, pad_idx)

        self.positional_encoding = AddPositionalEncoding(d_model, max_len, device)

        self.encoder_layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model, d_ff, heads_num, dropout_rate, layer_norm_eps
                )
                for _ in range(N)
            ]
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, mask)
        return x

3.7 Decoder

Encoderに続き、Decoderも実装していきましょう。Decoderは、Transformerの概要図における以下の部分です。DecoderもEncoder同様、Embedding,Positiona Encoding,Multihead Attention,FeedForward Networkから構成されます。

TransformerのDecoder

実装は以下の通りです。

Decoder
import torch
from torch import nn
from torch.nn import LayerNorm

from .Embedding import Embedding
from .FFN import FFN
from .MultiHeadAttention import MultiHeadAttention
from .PositionalEncoding import AddPositionalEncoding


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        heads_num: int,
        dropout_rate: float,
        layer_norm_eps: float,
    ):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, heads_num)
        self.dropout_self_attention = nn.Dropout(dropout_rate)
        self.layer_norm_self_attention = LayerNorm(d_model, eps=layer_norm_eps)

        self.src_tgt_attention = MultiHeadAttention(d_model, heads_num)
        self.dropout_src_tgt_attention = nn.Dropout(dropout_rate)
        self.layer_norm_src_tgt_attention = LayerNorm(d_model, eps=layer_norm_eps)

        self.ffn = FFN(d_model, d_ff)
        self.dropout_ffn = nn.Dropout(dropout_rate)
        self.layer_norm_ffn = LayerNorm(d_model, eps=layer_norm_eps)

    def forward(
        self,
        tgt: torch.Tensor,  # Decoder input
        src: torch.Tensor,  # Encoder output
        mask_src_tgt: torch.Tensor,
        mask_self: torch.Tensor,
    ) -> torch.Tensor:
        tgt = self.layer_norm_self_attention(
            tgt + self.__self_attention_block(tgt, mask_self)
        )

        x = self.layer_norm_src_tgt_attention(
            tgt + self.__src_tgt_attention_block(src, tgt, mask_src_tgt)
        )

        x = self.layer_norm_ffn(x + self.__feed_forward_block(x))

        return x

    def __src_tgt_attention_block(
        self, src: torch.Tensor, tgt: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        return self.dropout_src_tgt_attention(
            self.src_tgt_attention(tgt, src, src, mask)
        )

    def __self_attention_block(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        return self.dropout_self_attention(self.self_attention(x, x, x, mask))

    def __feed_forward_block(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout_ffn(self.ffn(x))


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        tgt_vocab_size: int,
        max_len: int,
        pad_idx: int,
        d_model: int,
        N: int,
        d_ff: int,
        heads_num: int,
        dropout_rate: float,
        layer_norm_eps: float,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        super().__init__()
        self.embedding = Embedding(tgt_vocab_size, d_model, pad_idx)
        self.positional_encoding = AddPositionalEncoding(d_model, max_len, device)
        self.decoder_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model, d_ff, heads_num, dropout_rate, layer_norm_eps
                )
                for _ in range(N)
            ]
        )

    def forward(
        self,
        tgt: torch.Tensor,  # Decoder input
        src: torch.Tensor,  # Encoder output
        mask_src_tgt: torch.Tensor,
        mask_self: torch.Tensor,
    ) -> torch.Tensor:
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        for decoder_layer in self.decoder_layers:
            tgt = decoder_layer(
                tgt,
                src,
                mask_src_tgt,
                mask_self,
            )
        return tgt

3.8 Transformer(完成版)

ここまでで、AttentionやPositionalEncoding,Encoder,DecoderといったTransformerを構成する各パーツの実装が完了しました。本章ではいよいよ、Transformer本体の実装に移ります。

といっても、EncoderとDecoderの実装が完了しているので、モデルの実装自体はその二つを組み合わせるだけで、とてもシンプルです。それでは実装を見てみましょう。

Transformer
import torch
from layers.transformer.TransformerDecoder import TransformerDecoder
from layers.transformer.TransformerEncoder import TransformerEncoder
from torch import nn


class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        max_len: int,
        d_model: int = 512,
        heads_num: int = 8,
        d_ff: int = 2048,
        N: int = 6,
        dropout_rate: float = 0.1,
        layer_norm_eps: float = 1e-5,
        pad_idx: int = 0,
        device: torch.device = torch.device("cpu"),
    ):

        super().__init__()

        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.heads_num = heads_num
        self.d_ff = d_ff
        self.N = N
        self.dropout_rate = dropout_rate
        self.layer_norm_eps = layer_norm_eps
        self.pad_idx = pad_idx
        self.device = device

        self.encoder = TransformerEncoder(
            src_vocab_size,
            max_len,
            pad_idx,
            d_model,
            N,
            d_ff,
            heads_num,
            dropout_rate,
            layer_norm_eps,
            device,
        )

        self.decoder = TransformerDecoder(
            tgt_vocab_size,
            max_len,
            pad_idx,
            d_model,
            N,
            d_ff,
            heads_num,
            dropout_rate,
            layer_norm_eps,
            device,
        )

        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        ----------
        src : torch.Tensor
            単語のid列. [batch_size, max_len]
        tgt : torch.Tensor
            単語のid列. [batch_size, max_len]
        """

        # mask
        pad_mask_src = self._pad_mask(src)

        src = self.encoder(src, pad_mask_src)

        # if tgt is not None:

        # target系列の"0(BOS)~max_len-1"(max_len-1系列)までを入力し、"1~max_len"(max_len-1系列)を予測する
        mask_self_attn = torch.logical_or(
            self._subsequent_mask(tgt), self._pad_mask(tgt)
        )
        dec_output = self.decoder(tgt, src, pad_mask_src, mask_self_attn)

        return self.linear(dec_output)

    def _pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """単語のid列(ex:[[4,1,9,11,0,0,0...],[4,1,9,11,0,0,0...],[4,1,9,11,0,0,0...]...])からmaskを作成する.
        Parameters:
        ----------
        x : torch.Tensor
            単語のid列. [batch_size, max_len]
        """
        seq_len = x.size(1)
        mask = x.eq(self.pad_idx)  # 0 is <pad> in vocab
        mask = mask.unsqueeze(1)
        mask = mask.repeat(1, seq_len, 1)  # (batch_size, max_len, max_len)
        return mask.to(self.device)

    def _subsequent_mask(self, x: torch.Tensor) -> torch.Tensor:
        """DecoderのMasked-Attentionに使用するmaskを作成する.
        Parameters:
        ----------
        x : torch.Tensor
            単語のトークン列. [batch_size, max_len, d_model]
        """
        batch_size = x.size(0)
        max_len = x.size(1)
        return (
            torch.tril(torch.ones(batch_size, max_len, max_len)).eq(0).to(self.device)
        )

4.Transformerの学習

さて、前章まででTransformerの実装したところで、 ここでは実装したTransformerを実際の機械翻訳データセットを使用して、学習していきます。

モデルを学習するにあたって、モデルの学習を目的としたクラス・Trainerをtrain.pyに定義しました。なお、TrainerクラスはPyTorch LightningのAPIを参考にしており、以下の5つのメソッドを持ちます。

  • loss_fn: 誤差関数の計算
  • train_step: バッチ学習における1ステップ(訓練)
  • val_step: バッチ学習における1ステップ(検証)
  • fit: バッチ学習によるモデルの訓練・検証を行う
  • test: テストデータによるモデルの検証を行う

というわけで、実装を見てみましょう。Trainerクラスの実装は以下の通りです。

Trainerクラス
from os.path import join
from typing import List, Tuple

import torch
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader

from const.path import (
    FIGURE_PATH,
    KFTT_TOK_CORPUS_PATH,
    NN_MODEL_PICKLES_PATH,
    TANAKA_CORPUS_PATH,
)
from models import Transformer
from utils.dataset.Dataset import KfttDataset
from utils.evaluation.bleu import BleuScore
from utils.text.text import tensor_to_text, text_to_tensor
from utils.text.vocab import get_vocab


class Trainer:
    def __init__(
        self,
        net: nn.Module,
        optimizer: optim.Optimizer,
        critetion: nn.Module,
        bleu_score: BleuScore,
        device: torch.device,
    ) -> None:
        self.net = net
        self.optimizer = optimizer
        self.critetion = critetion
        self.device = device
        self.bleu_score = bleu_score
        self.net = self.net.to(self.device)

    def loss_fn(self, preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return self.critetion(preds, labels)

    def train_step(
        self, src: torch.Tensor, tgt: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, float]:
        self.net.train()
        output = self.net(src, tgt)

        tgt = tgt[:, 1:]  # decoderからの出力は1 ~ max_lenまでなので、0以降のデータで誤差関数を計算する
        output = output[:, :-1, :]  #

        # calculate loss
        loss = self.loss_fn(
            output.contiguous().view(
                -1,
                output.size(-1),
            ),
            tgt.contiguous().view(-1),
        )

        # calculate bleu score
        _, output_ids = torch.max(output, dim=-1)
        bleu_score = self.bleu_score(tgt, output_ids)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss, output, bleu_score

    def val_step(
        self, src: torch.Tensor, tgt: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, float]:
        self.net.eval()
        output = self.net(src, tgt)

        tgt = tgt[:, 1:]
        output = output[:, :-1, :]  #

        loss = self.loss_fn(
            output.contiguous().view(
                -1,
                output.size(-1),
            ),
            tgt.contiguous().view(-1),
        )
        _, output_ids = torch.max(output, dim=-1)
        bleu_score = self.bleu_score(tgt, output_ids)

        return loss, output, bleu_score

    def fit(
        self, train_loader: DataLoader, val_loader: DataLoader, print_log: bool = True
    ) -> Tuple[List[float], List[float], List[float], List[float]]:
        # train
        train_losses: List[float] = []
        train_bleu_scores: List[float] = []
        if print_log:
            print(f"{'-'*20 + 'Train' + '-'*20} \n")
        for i, (src, tgt) in enumerate(train_loader):
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            loss, _, bleu_score = self.train_step(src, tgt)
            src = src.to("cpu")
            tgt = tgt.to("cpu")

            if print_log:
                print(
                    f"train loss: {loss.item()}, bleu score: {bleu_score},"
                    + f"iter: {i+1}/{len(train_loader)} \n"
                )

            train_losses.append(loss.item())
            train_bleu_scores.append(bleu_score)

        # validation
        val_losses: List[float] = []
        val_bleu_scores: List[float] = []
        if print_log:
            print(f"{'-'*20 + 'Validation' + '-'*20} \n")
        for i, (src, tgt) in enumerate(val_loader):
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            loss, _, bleu_score = self.val_step(src, tgt)
            src = src.to("cpu")
            tgt = tgt.to("cpu")

            if print_log:
                print(f"train loss: {loss.item()}, iter: {i+1}/{len(val_loader)} \n")

            val_losses.append(loss.item())
            val_bleu_scores.append(bleu_score)

        return train_losses, train_bleu_scores, val_losses, val_bleu_scores

    def test(self, test_data_loader: DataLoader) -> Tuple[List[float], List[float]]:
        test_losses: List[float] = []
        test_bleu_scores: List[float] = []
        for i, (src, tgt) in enumerate(test_data_loader):
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            loss, _, bleu_score = trainer.val_step(src, tgt)
            src = src.to("cpu")
            tgt = tgt.to("cpu")

            test_losses.append(loss.item())
            test_bleu_scores.append(bleu_score)

        return test_losses, test_bleu_scores

それでは、モデルを学習していきます。データセットには日本語と英語の翻訳コーパスである京都フリー翻訳タスク(KFTT)を用いました。モデルの学習には以下のコマンドを実行します。

$ poetry run python train.py

上のコマンドを実行すると、以下のような出力が得られます。

$ poetry run python trainpy

epoch: 1
--------------------Train--------------------

train loss: 10.104473114013672, bleu score: 0.0,iter: 1/4403

train loss: 9.551202774047852, bleu score: 0.0,iter: 2/4403

train loss: 8.950608253479004, bleu score: 0.0,iter: 3/4403

train loss: 8.688143730163574, bleu score: 0.0,iter: 4/4403

train loss: 8.4220552444458, bleu score: 0.0,iter: 5/4403

train loss: 8.243291854858398, bleu score: 0.0,iter: 6/4403

train loss: 8.187620162963867, bleu score: 0.0,iter: 7/4403

train loss: 7.6360859870910645, bleu score: 0.0,iter: 8/4403

....

以下は1ステップ当たりの訓練ロスの推移をプロットしたグラフです。

image

実行環境の都合で、「モデルを十分に訓練し、テストデータに対するBleuスコア・Lossで検証を行う」という検証はできませんでしたが(すみません🙇‍♂️)、1ステップあたりの訓練ロスが順調に減っていることから、正しく学習できていると推察されます。

5.まとめ

一つ一つ書いていたらとても長くなってしまいましたが、以上が実装になります!

Transformerは登場して既に5年が経過しているので、モデル自体は新しいとは言い難いです。しかし、近年自然言語処理の分野でSoTAとされるモデルのほとんどは Transformer・Attention機構をベースとしています。そのため、Transformerは深層学習のトレンドを掴む上で、非常に重要なモデルといって良いでしょう。

最近はライブラリが充実しており、Transformerのような複雑なモデルの実装に対するハードルもかなり下がっています。ですが、内部構造を理解しておくことは応用力を高める上で非常に重要だと思うので、興味のある方は今回のような車輪の再発明も是非試してみてください。

6.参考文献

https://arxiv.org/abs/1706.03762

https://www.amazon.co.jp/詳解ディープラーニング-TensorFlow-Keras・PyTorchによる時系列データ処理-Compass-Booksシリーズ/dp/4839969515

https://www.amazon.co.jp/深層学習-改訂第2版-機械学習プロフェッショナルシリーズ-岡谷-貴之/dp/4065133327

https://www.amazon.co.jp/PyTorch実践入門-Eli-Stevens/dp/4839974691

https://www.amazon.co.jp/ゼロから作るDeep-Learning-―自然言語処理編-斎藤-康毅/dp/4873118360/ref=pd_bxgy_img_1/358-0651022-5160614?pd_rd_w=JO4Kw&pf_rd_p=020fee25-8ced-4191-bce3-27e7ce0c0e3b&pf_rd_r=WM1R7J4578P1B0MCNSJE&pd_rd_r=2b8a68ac-2514-4675-9132-acafd1cf2853&pd_rd_wg=kDhDj&pd_rd_i=4873118360&psc=1

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py

https://qiita.com/halhorn/items/c91497522be27bde17ce

https://deeplearning.hatenablog.com/entry/transformer