Show, Attend and Tell を動かす,キャプション生成


読み込んだ画像からキャプションを生成するモデルの一つ Show, Attend and Tell の動かし方を書いておきます.
https://github.com/kelvinxu/arctic-captions
(Issueみると結構困ってる人が多いのと自分も最初ちょっと躓いたので)

おそらくほとんどの人はデータセットの作り方がわからないのだと思うのでそのあたりに触れます.

キャプション生成例,馬が2匹見えてるっぽい

まず実行してみる

https://github.com/AAmmy/show-attend-and-tell-8x512
を使用します.Readmeにあるライブラリが必要です.Caffeは必要ありません.

画像の特徴とキャプションを
http://cs.stanford.edu/people/karpathy/deepimagesent/
(COCO 750MB) からダウンロードし,coco.py の下のほうに記述されている場所に配置し,
scripts.py のメモにあるように
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python evaluate_coco.py
を実行すれば学習が開始されます.

250 update ごとにキャプション例と生成されたキャプションが表示されます.最初のうちは何も出力されませんが徐々に文が出来上がっていきます.
TITAN X で3時間程度で BLEU 値が論文と同じくらいになります.15時間程度で early stop になり学習が終了します.

何か画像を入れてみる

あとで(generate_capsとneraltalkのextract feature py つかえばできる)

画像の特徴

ダウンロードしたファイルが含む画像の特徴は4096次元で,これを coco.py で 8x512 にして使用しています.
論文では 196x512 に取り出したものを使用していますが,4096次元でもうまくいっています.attention の可視化はそのままではできません.

align.pkl

本来のコードでは事前にデータセットと辞書を学習,開発,テストにまとめたものを作っておくのですが,私のコードでは coco.py で学習開始時に作っています.
データの形式は,

train = [
            [["a dog is running.", 0], ...], # キャプション
            [[000000000000], ...] # 特徴
        ]

のようになっています.キャプションは文と数字がリストになっており,数字はその文を持つ画像の特徴ベクトルが特徴リストの何番目に入っているかを意味しています.

辞書

全キャプションの単語数を数えて頻度順に数字を出力するようにしているだけです.collectionsうまく使えばもっと楽に作れるみたいですが.

coco.py
from collections import OrderedDict
dict_count = {}
for cs in cap_:
    for c in cs:
        for w in c.split():
            if w in dict_count:
                dict_count[w] += 1
            else:
                dict_count[w] = 1
dict_ordered_by_count = OrderedDict(sorted(dict_count.items(), key = lambda x:x[0]))
dict_ordered_by_count = OrderedDict(sorted(dict_ordered_by_count.items(), key = lambda x:-x[1]))
worddict = {k:i+2 for i,k in enumerate(dict_ordered_by_count.keys())}

例えば "a" が最も多く出現した単語であった場合は,
worddict['a'] = 2
次に多く出現した単語が "on" であった場合は,
worddict['on'] = 3
と出力されます.0 と 1 は <eos> と UNK に開けておきます.

経緯

最初に Show, Attend and Tell を触ったときにうまく動かなかった (入力データの形式がわからなかった) ので newraltalk を使ってみた.https://github.com/karpathy/neuraltalk
なんとなく理解したので Show, Attend and Tell を使う.特徴抽出は時間がかかるので newraltalk の人たちが使用していたものを流用.Flickr8k では学習は失敗.30k は試してない.
その後 attention を表示してみたくなったので画像から CNN で特徴抽出を行うが,BLEU値が低い.
Issue を見ると BLEU 値が低いと言っている人が何人かいる.

どうも特徴抽出の方法が良くないらしい.(Issue にいる人たち含め)
http://www.statmt.org/wmt16/multimodal-task.html
試していないのがここに論文と同じ形式の Flickr30k のデータがおいてある.

まとめ

入力データができればすぐに動く.
特徴ベクトルうまく抽出してないといい BLEU でない.
学習パラメータは capgen.py の train() に書いてあるのが flickr8k 用,evaluate_coco.py に書いてあるのが coco用.flickr30k はcocoのパラメータの隠れ層を 1300 とかにするとよいらしい.