Generating visual explanations (GVE)を使ってみる


0. 概要

XAI人気ですよね。特にDARPAも力を入れているAIの判断基準を如何に可視化するか。これは、入力画像の何処に着目したか色を付けて画像で可視化する場合もあれば、文字でどこを見ているのかを言うなんていう表現方法等様々あります。
https://www.darpa.mil/attachments/XAIProgramUpdate.pdf
このDarpaの資料でも触れられていますが、その物体を識別した時の理由を文字で表現できるLRCNが生まれました。
それがGVEです。
https://www.mpi-inf.mpg.de/fileadmin/inf/d2/akata/generating-visual-explanations.pdf

こんな感じで、鳥を識別した理由を話してくれます。VQAと異なるのは質問しなくていいということですね。
このGVEをまずは動かしたいと思います。本家実装はCaffeらしいですが、Pytorch版も本家からReferされているのそちらを使いたいと思います。
だってCaffe2からPytorchに吸収されるんでしょう(?)、ならPytorchが良いじゃないということで。

1. Installation

Gitを見ながら進めて行く。
https://github.com/salaniz/pytorch-gve-lrcn

まずはファイルを設置

$ git clone https://github.com/salaniz/pytorch-gve-lrcn.git
$ cd pytorch-gve-lrcn

次に、環境を整える。既にPytorchが入っている人は飛ばしておk

$ conda env create -f environment.yml
$ conda activate gve-lrcn

CUBというデータセットを使うので、データセットがない人はダウンロード

$ ./cub-data-setup-linux.sh

dataに保存される。

2. Training

環境構築が終わったのでトレーニングをして行く。

まず、以下でEmbeded文書を学習

$ python main.py --model sc --dataset cub

学習が終わるとbest-ckpt.pthという最適化されたWeightが保存される。
これをdataフォルダにコピーしてからgveの学習をする。
まずはコピー。

$ cp ./checkpoints/sc-cub-D<date>-T<time>-G<GPUid>/best-ckpt.pth ./data/cub/sentence_classifier_ckpt.pth

次に学習。

$ python main.py --model gve --dataset cub --sc-ckpt ./data/cub/sentence_classifier_ckpt.pth

こんな感じで亀の歩みで学習が進んでいく。

3. Evaluation

学習が終わると以下で評価を実行できる。

$ python main.py --model gve --dataset cub --eval ./checkpoints/gve-cub-D<date>-T<time>-G<GPUid>/best-ckpt.pth

学習回数が少ないのかMETEOR低い・・・

4. Prediction

以下で学習機及び予測機を作っている。

trainer_creator = getattr(TrainerLoader, args.model)
trainer = trainer_creator(args, model, dataset, data_loader, logger, device)

学習機及び予測機の切り替えは以下で行っている。

model.train()
model.eval()

このため、以下でデータセットにロードされたデータの識別が行える。

model.eval()
trainer_creator = getattr(TrainerLoader, 'gve')
trainer = trainer_creator(args, model, dataset, data_loader, logger, device)
res = trainer.train_epoch()

識別された画像のIDと文字が返ってくるのはいいが、入力画像は特徴量しかないため画像化はできない笑
画像IDを渡せば、以下のように入力の特徴量を見ることも出来る笑

img = dataset.get_image("081.Pied_Kingfisher/Pied_Kingfisher_0132_72706.jpg")
print(img)

こんな感じ。

Label: This magnificent bird has a white belly, throat, and crown with a black superciliary, yellow torso, and dark wings with light secondaries.
Prediction: This is a bird with a white belly black wings and a long tail

まあ、正しい?