[機械学習]Transformerモデルを使ったテキスト分類(Attentionベースの識別器)


ゴール

Tensorflowのチュートリアルにある「言語理解のためのTransformerモデル」の一部を修正して、
テキスト分類のタスクができるようになることです。

Notebook

NotebookをGithubにアップしてあります。
transformer_classify

解説

チュートリアルとの主な差分を以下に記載します。

使用データはlivedoorニュースコーパス

  • 本記事で紹介した分類タスクは、業務に活用する場合には日本語の文書分類になると想定しています。
  • そのためデータは機械学習でよく利用されるlivedoorニュースコーパスを利用させていただきました。

文章の分かち書きにJumanを使用

  • 日本語の分かち書きに定評のあるJumanを使っています。
  • Jumanのダウンロード、インストールを自動化したDockerfileはこちら

Decoderの削除

  • DecoderはEncoderのアウトプットを受け取り他言語ベクトルに変換する仕組みです。
  • 今回は他言語ベクトルへの変換ではなく分類タスクであるため、Decoderは利用しません。

Transformerの修正

  • Decoderを削除する代わりに、Encoderで得られたアウトプットにDenseレイヤーを重ね、これを出力層として追加します。
  • インプットとなるテキストベクトルが、どのクラスに分類されるかを確率的に表現した値に変換するため、 活性化関数にはSoftmax関数を用いています。
transformer_classify.ipynb
NUMLABELS = 9

class Transformer(tf.keras.Model):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super(Transformer, self).__init__()

    self.encoder = Encoder(num_layers, d_model, num_heads, dff, 
                           input_vocab_size, pe_input, rate)
    self.dense1 = tf.keras.layers.Dense(d_model, activation='tanh')
    self.dropout1 = tf.keras.layers.Dropout(rate)   
    self.final_layer = tf.keras.layers.Dense(NUMLABELS, activation='softmax')

  def call(self, inp, tar, training, enc_padding_mask):

    enc_output = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)
    enc_output = self.dense1(enc_output[:,0])
    enc_output = self.dropout1(enc_output, training=training)
    final_output = self.final_layer(enc_output )  # (batch_size, tar_seq_len, target_vocab_size)

    return final_output

損失関数

  • 出力層の活性化関数をSoftmax関数を用いているため、損失関数には多クラス交差エントロピーを使います。
  • One-hotベクトル化していないので、SparseCategoricalCrossentropy()を利用しています。
transformer_classify.ipynb
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

def loss_function(labels, pred):
  loss_ = loss_object(labels, pred)
  return loss_

val_stepの追加

  • train_stepの後にvalidデータを用いたval_stepを追加しています。
  • validationであるため、dropoutレイヤーをスキップさせるためにtraininngはfalseにセットしてます。

結果

あまりよい精度は出せませんでした。

参考URL

tf2_classify
BERT with SentencePiece for Japanese text.
作って理解する Transformer / Attention
Transformer model for language understanding