CloudTPUでtransformer英日翻訳モデルを学習&推論する手順


公式のチュートリアル通りに進めていったら上手くいかなかったのでメモ。
本記事では、CloudTPU と GCE VMインスタンスを立ち上げ、NMTモデルの一つであるtransformerで英日翻訳モデルを構築する手順を説明します。

前提

  • Google Cloud Platformでプロジェクトを作成済みであること
  • 作成済みプロジェクトに対して課金が有効になっていること

Cloud Console

Cloud Consoleで以下を入力し、CloudTPUとGCE VM インスタンスを新規に立ち上げます。

cloud_console
# プロジェクトIDのセット
gcloud config set project <project_id>
# ctpuを起動(ctpu名はtransformer)
# GCEインスタンスも起動する
ctpu up --name=transformer --tf-version=1.14

公式チュートリアルではctpu upで起動することになっていますが、GCE VMインスタンスのデフォルトのtensorflowとバージョンが合わないため、チュートリアル通りに進めるとエラーが発生します。
チュートリアル通りに進めるにはCloudTPUのtensorflowのバージョンをGCE VMインスタンスのものと揃える必要があります。

GCE(Google Computing Engine)

GCS(Google Cloud Strage)に格納した自前のデータセット(英語-日本語の対訳)をもとにtransformerモデルで学習、推論する手順を説明します。
以下、ctpu upで作成したGCEのVMインスタンスにSSHで接続した状態で進めていきます。

VMインスタンス内のディレクトリ構成

.
├── src
│   ├── __init__.py
│   └── myproblem.py
└── tmp
    └── t2t_tmp
        └── sample.picke

学習データセットをGCSからGCEにダウンロード

gsutil cp gs://<budge_name>/sample.pickle ./tmp/t2t_tmp/sample.pickle

ここで、sample.pickleはenglish(英文)とjapanese(日本語文)の2カラム構成のデータフレームとします。

PROBLEの定義

独自のデータセットを使う場合は、PROBMEを実装して登録してあげる必要があります。
参考:https://tensorflow.github.io/tensor2tensor/new_problem.html
ここでは、以下の2つのPythonスクリプトを作成しておきます。

./src/__init__.py
from . import myproblem
./src/myproblem.py
import pickle

import numpy as np

from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**13

    @property
    def is_generate_per_split(self):
        return False

    @property
    def dataset_splits(self):
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 9,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        with open('./tmp/t2t_tmp/sample.pickle', 'rb') as fin:
            sentences = pickle.load(fin)
        for row in np.array(sentences):
            yield {'inputs': row[0], 'targets': row[1]}

VMインスタンス内の環境変数を設定

# 環境変数のセット
export STORAGE_BUCKET=gs://<project_name>
export DATA_DIR=$STORAGE_BUCKET/transformer
export TMP_DIR=/tmp/t2t_tmp
export PATH=.local/bin:$PATH
export PROBLEM=translate_jpen
export TRAIN_DIR=$STORAGE_BUCKET/training/transformer_ende
export MODEL=transformer
export HPARAMS=transformer_tpu
# 自作スクリプト
export USR_DIR=./src

前処理と学習

自作した./src/myproblem.pyをもとに前処理後、学習します。
ここで、cloud_tpu_namectpu upで指定したnameを直に指定します。($TPU_NAMEで指定するとエラーになります)
参考:https://stackoverflow.com/questions/59089613/tpu-core-error-on-google-cloud-platform-cannot-find-any-tpu-cores-in-the-system

データ量にもよりますが、約6万対訳のデータセットで3時間程度かかりました。

# 前処理
t2t-datagen \
  --problem=$PROBLEM \
  --data_dir=$DATA_DIR \
  --tmp_dir=$TMP_DIR \
  --t2t_usr_dir=$USR_DIR

# 学習
t2t-trainer \
  --data_dir=$DATA_DIR \
  --problem=$PROBLEM \
  --train_steps=40000 \
  --eval_steps=3 \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR \
  --t2t_usr_dir=$USR_DIR \
  --use_tpu=True \
  --cloud_tpu_name=transformer

推論

学習後、推論を実行します。
decode_interactiveパラメータをTrueにすることで、インタラクティブシェルで翻訳を実行できます。
CloudTPUでの学習結果をもとにローカルで推論したい場合は以下をご参照ください。
https://qiita.com/yolo_kiyoshi/items/209750f27f582ed48257

# 推論
t2t-decoder \
   --data_dir=$DATA_DIR \
   --problem=$PROBLEM \
   --model=$MODEL \
   --hparams_set=$HPARAMS \
   --output_dir=$TRAIN_DIR \
   --t2t_usr_dir=$USR_DIR \
   --decode_hparams="beam_size=4,alpha=0.6 \
   --decode_interactive=true

参考