Huggingface の DistilBERT を使って、Android で NLP を!


こんにちは。@rheza_h です。
この記事はACCESS Advent Calendar 2019 22日目の記事です。

NLP は初めてですが、今回の記事は Huggingface が提供している DistilBERT を紹介し、それを使って、Android 上で動かしてみた記事です。
この記事で BERT については書かれていません。
BERT については論文 または @shotasakamoto のプレゼン資料が参考になれると思います。
社内リンク

Huggingface とは?

Natural Language Processing (NLP - 自然言語処理) を中心に研究開発をやっています。
Huggingface は 2016 に Brooklyn, New York で始まりました。
2017 にチャットボットをリリースしました。
Huggingface は自社の NLP モデルを開発して、Hierarchical Multi-Task Learning (HTML) と呼ばれています。
Chatty, Talking Dog, Talking Egg, Boloss と言う iOS アプリを開発しています。
DistilBERT と言うモデルを NeurIPS 2019 に公開されました。

DistilBERTの話

小さく、早く、安く、軽く

state-of-the-arts の NLP モデルはほとんど large-scale language model を使われています。Transformer (Vaswani et al.,) のベースで研究されていて、最近の state-of-the-art モデルのパラメーターのサイズが大きくなっています。例えば NVIDIA が作った、MegatronLM と言うモデルは 8.3億パラメーターがあります。それは約160GBのテキストデータで学習されているそうです。


図1. NLP の進捗

DistilBERT は名前の通り、"Distil"、必要・大事な部分だけを使用して、モデルが小さくして、精度は耐えられる程度で研究しています。
Knowledge Distillation [Bucila et al., 2006, Hinton et al., 2015] と言う方法をやっています。"Teacher-Student" 学習でもよく言われています。
モデルは2つあります。

  • "学生"モデルと呼ばれています。
    • 小さなモデル
    • このモデルが"先生"モデルに似たような結果を出せるよう期待しています
  • "先生"モデル
    • ベースモデルは"学生"モデルと同じ

Knowledge Distillation の流れはいろんな流れがありますが、一番大切なのは Loss Function を組み合わせところだと思います。

pre-trainedを使う場合、

図2. Knowledge Distillation

DistilBERT は名前の通り、BERT がベースになりますが、小さいバージョンが作られました。
ほとんどのアーキテクチャはそのままにさせましたが、token-type embeddings と pooler だけを削除し、レイヤーの数を2の因数に減らします。

評価 (モデルパフォーマンス)

DistillBERT は学習のため、BERT と同じコルパスを使って、英語の Wikipedia と Toronto Book Corpus [Zhu et al., 2015]です。
8つの 16GB V100 GPUs を使って 90時間ぐらいかかったそうです。

BERT と比べたら、パラメーターの数は 40% 少ないし、パフォーマンスは約2~3%しか減っていません。

Android の実装

Huggingface は以下のレポジトリでサンプルアプリを公開します。
https://github.com/huggingface/tflite-android-transformers

DistilBERTをスクラッチから学習すると時間がかかりすぎるので、huggingface は tflite版のモデルを公開されています。
(コードは Huggingface の github で公開されています)
その tflite モデルをロードして、inference に使えばうまく動いています。

モデルのロード

private static final int NUM_LITE_THREADS = 4;
public synchronized void loadModel() {
    try {
      ByteBuffer buffer = loadModelFile(this.context.getAssets());
      Interpreter.Options opt = new Interpreter.Options();
      opt.setNumThreads(NUM_LITE_THREADS);
      tflite = new Interpreter(buffer, opt);
      Log.v(TAG, "TFLite model loaded.");
    } catch (IOException ex) {
      Log.e(TAG, ex.getMessage());
    }
}

public MappedByteBuffer loadModelFile(AssetManager assetManager) throws IOException {
    // MODEL_PATH は assets にあります
    try (AssetFileDescriptor fileDescriptor = assetManager.openFd(MODEL_PATH); 
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
      FileChannel fileChannel = inputStream.getChannel();
      long startOffset = fileDescriptor.getStartOffset();
      long declaredLength = fileDescriptor.getDeclaredLength();
      return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
}

predict をする前に、query の feature extraction をやることが必要です。
空白とか、query に要らない部分を抜いて、必要な部分だけを predict に渡します。

Tokenization と関係するコードは以下のところで
https://github.com/huggingface/tflite-android-transformers/tree/master/bert/src/main/java/co/huggingface/android_transformers/bertqa/tokenization

例:

感想

  • DistilBERT のおかげで BERT でも Android 端末で実行することができます。
  • tflite のフォマットでpre-trained モデルが公開されているので、すぐ使えます。
  • 答えの検索処理はコンテンツ内容によります。公開されたデータセットを使ってみるところ、Nexus 6P 端末で速度は 1~ 3秒ぐらいかかります。
  • モデルのサイズは結構大きい (250MB) です。
  • 次は日本語のモデルを試してみたいです。以下には CL-Tohoku が公開された 日本語の BERT です。 https://github.com/cl-tohoku/bert-japanese

最後に

明日は @hey3 の初めての記事です。お楽しみに!

参考

https://huggingface.co/
https://golden.com/wiki/Hugging_Face
https://techcrunch.com/2019/12/17/hugging-face-raises-15-million-to-build-the-definitive-natural-language-processing-library/
DistilBERT Paper
BERT Paper
Knowledge Distillation Paper
https://medium.com/huggingface/distilbert-8cf3380435b5
https://github.com/huggingface/tflite-android-transformers
https://nervanasystems.github.io/distiller/knowledge_distillation.html
https://towardsdatascience.com/model-distillation-and-compression-for-recommender-systems-in-pytorch-5d81c0f2c0ec
https://qiita.com/nekoumei/items/7b911c61324f16c43e7e