BERT+Grad-CAM (PyTorch)で根拠を可視化する


概要

 Grad-CAMはCNNベースのモデルに対しての視覚的な説明を作成する手法です.本記事ではGrad-CAMを言語処理モデルBERTに実装することでどのような結果になるかを確認していきます.

畳み込みベース以外のモデルに適応できるのか?

 畳み込みが使用されるモデル(特にFully Convolutional Networks)では,ピクセルとその周辺ピクセルの情報を抽出するため位置情報とモデルの出力がダイレクトに対応するように思えますが,BERTなどのSelf-Attentionベースのモデルでは周辺だけではなく全体から重みを付けて情報を抽出するため位置情報とモデルの出力が対応しません.

 上記の理由からBERTにGrad-CAMを適応させるのは適切ではありません.本記事ではそれを承知しながら実装していますが,上記を覆す根拠などが無い限り研究などには応用しづらいと考えています.

準備

 BERTはHuggingFaceから公開されている文のネガポジ判定モデルdaigo/bert-base-japanese-sentimentを使用します.

 python = "3.6.8"
 pytorch = "1.6.0"
 pip install transformers
 pip install japanize_matplotlib

コード

 ・インポート

import types
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib

 ・トークナイザとモデルを用意します.

tokenizer = AutoTokenizer.from_pretrained("daigo/bert-base-japanese-sentiment")
model = AutoModelForSequenceClassification.from_pretrained("daigo/bert-base-japanese-sentiment")

 ・順伝播&逆伝播時の内部状態と勾配を取得するための関数を用意しておく.

def save_feature_map(self, module, input, output):
    self.feature_maps.append(output[0].detach())

def save_grad(self, module, grad_in, grad_out):
    self.grads.append(grad_out[0].detach())

def init_cam(self):
    self.feature_maps = []
    self.grads = []

model.save_feature_map = types.MethodType(save_feature_map,model)
model.save_grad =  types.MethodType(save_grad,model)
model.init_cam = types.MethodType(init_cam,model)

model.feature_maps = []
model.grads = []

 ・CNNベースの画像認識モデルでは抽出された特徴的を最下層から取得するが,Self-Attentionベースの言語モデルでは一概に下層ほど特徴がより正確にエンコードされているとは限らないため,全ての層のアウトプット時の状態と勾配を取得する.

for i in range(12):
    model.bert.encoder.layer[i].register_forward_hook(model.save_feature_map)
    model.bert.encoder.layer[i].register_backward_hook(model.save_grad)

 ・下記の様にサンプル文を解析すると出力ラベルは1(ネガティブ)と出力されました.

tokens = tokenizer("私は不運な人間です.",return_tensors='pt')
output = model(tokens["input_ids"],tokens["token_type_ids"],tokens["attention_mask"],labels=torch.tensor([0]))
logits = output.logits.detach()
pred = logits.argmax(-1) ##1
one_hot = torch.nn.functional.one_hot(pred) ##[0,1]

 ・下記のような逆伝播を行う.

logits = output.logits
loss = torch.sum(logits*one_hot)
loss.backward()

 ・可視化に関する数値を取得する.camsに各層の数値が追加される.

cams = []
for i in range(12):
    weights = np.mean(list(reversed(model.grads))[i].cpu().numpy(), axis=1)[0, :]
    features = model.feature_maps[i].squeeze().transpose(0,1).numpy()
    cam = np.sum(features * weights[:, None], axis=0)
    cam = np.maximum(cam, 0)
    cams.append(cam)

 ・入力文をデコードする.

decode_tokens = [tokenizer.decode(token,skip_special_tokens=True).replace(" ","") for token in tokens["input_ids"].squeeze().tolist()]

 ・可視化を行う.
(数値が0の場合は水色,大きくなるほど赤に)

cams_matrix = np.concatenate([np.expand_dims(cam,axis=0) for cam in cams],axis=0)

fig, ax = plt.subplots(figsize=(12, 12))
ax.set_xticks(np.arange(len(decode_tokens)))
ax.set_xticklabels(decode_tokens)
ax.set_title('CAM_Sample')
plt.imshow(cams_matrix, cmap='cool', interpolation='nearest')
plt.savefig("CAM_Sample.png")
plt.show()



Grad-CAM with PyTorchを参考にさせていただきました.

結果

 ・上記のコードではネガティブ判定されるサンプル文を使っていますが,ポジティブ判定される文でも実行してみました.
 ・縦軸は層を示しています.(例えば目盛りの0は1層目での数値を示しています)
 ・「不運」や「楽し」などの特徴的な部分が色付けされているのが確認できますが,画像認識モデルほどは的確にできません.

まとめ

 ・可視化できる根拠無しにBERTに対してGrad-CAMを適応し,可視化を行いました.
 ・CNNベースの画像認識モデルほど的確に可視化を行うことはできませんでしたが,それとなく可能性を感じるような結果でした.

最後に

 ・誤っている部分等ございましたら,コメント等で優しく指摘して頂けると嬉しいです.(気付かなかったら申し訳ありません)