fairseqで自分でトレーニングしたTransformerモデルをロードする


はじめに

機械翻訳のTransformerモデルをトレーニングする機会があり,Pytorchベースのfairseqを使ったんですが,アプリ用のコードでモデルのロードにハマってしまいました.備忘録のために記事を書きます.

fairseqとは

Facebookの人工知能研究チームが開発している,機械翻訳用のフレームワークです.Facebookが開発元ということもあり,Pytorchがベースになっています.最近はHuggingfaceのTransformersが人気でTransformerモデルを扱うならPytorchだよね,ということもあり,こちらをフレームワークとして選びました.
その他の機械翻訳フレームワークとしては,MarianNMT,OpenNMT(こちらもPytorchベース)などがあります.基本的な機能はどのフレームワークも大差ない印象ですが,論文実装のコードはfairseqが選ばれている場合が多いです.

モデルのトレーニング方法

fairseqはドキュメントが充実しているため,チュートリアル通りにやれば,自分でも機械翻訳モデルをトレーニングすることができます.チュートリアルは英語のみですが,日本語のドキュメントが不足しているのは,他の機械翻訳フレームワークでも同様です.

トレーニングのためのコード(fairseqの公式ドキュメントから)
> mkdir -p checkpoints/fconv
> CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
    --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
    --arch fconv_iwslt_de_en --save-dir checkpoints/fconv

トレーニングコードのsave-dirは学習済みモデルの格納フォルダですが,このフォルダには,各エポックにおける学習済みモデル(checkpoint(n).pt),精度が一番よかった学習済みモデル(checkpoint_best.pt),最終エポックにおける学習済みモデル(checkpoint_last.pt)が格納されます.

checkpoints_fconv
$ ls -alF
checkpoint1.pt
checkpoint2.pt
...
checkpoint99.pt
checkpoint_bes.pt
checkpoint_last.pt

推論の際は,checkpoint_best.ptを読み込むことになるかと思います.

インタラクティブモードでの推論

fairseqの公式ドキュメントには,シェル上のインタラクティブモードでの推論についてはチュートリアルがありますが,Pythonのコード内でモデルをロードする方法などは紹介されていません.

fairseq-interactiveによ推論(fairseqの公式ドキュメントから)
> MODEL_DIR=wmt14.en-fr.fconv-py
> fairseq-interactive \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5 --source-lang en --target-lang fr \
    --tokenizer moses \
    --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
Why is it rare to discover new marine mammal species?
S-0     Why is it rare to discover new marine mam@@ mal species ?
H-0     -0.0643349438905716     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
P-0     -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015

Pythonのコード内で学習済みモデルをロードする方法

modelsクラスを使うのが正解でした.インタラクティブモードの場合と異なり,いくつかの引数は自動でロードされるようです.

>>> from fairseq.models.transformer import TransformerModel
>>> model = TransformerModel.from_pretrained('wmt14.en-fr.fconv-py', 'model.pt', 'wmt14.en-fr.fconv-py')
>>> text = 'Why is it rare to discover new marine mam@@ mal species ?'
>>> model.translate(text, beam=5)
'Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?'

まとめ

fairseqは日本語ドキュメントが少ないですが,HuggingfaceのTransformersを使っている人であれば使いやすいですし,Pytorchベースということにも将来性を感じます.
論文実装のコードも多いですし,今後も触っていきたいフレームワークですね.