XceptionHourgrassに猫の写真とマスク画像を学習させてみた


XceptionHourgrassに猫の写真とマスク画像を学習させてみた

にゃーん(笑)

XceptionHourgrassとは

こちらで公開しているニューラルネットワークのモデルです。
マスク画像やヒートマップ画像を作るためのニューラルネットワークです。
詳しいアーキテクチャとかは、前回の投稿を参照してください。

何をやったか

MS-COCOデータセットから、猫の写真とマスク画像をダウンロードしてきて、XceptionHourgrassに学習させてみました。
GitHubのreadmeに書いてある手順そのままです。

XceptionHourgrassのダウンロード

まずはGitHubからCloneします。

$ git clone https://github.com/tanreinama/XceptionHourgrass---PyTorch.git
$ cd XceptionHourgrass---PyTorch

MS-COCOデータセットのダウンロード

COCO2014のtrainデータセットを使用します。
データサイズが大きいので注意。以下のコマンドで、画像データと、アノテーションデータをダウンロードし、解凍します。

$ wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
$ wget http://images.cocodataset.org/zips/train2014.zip
$ unzip annotations_trainval2014.zip
$ unzip train2014.zip

猫の写真とマスク画像の用意

サンプルのトレーニング用PGは、COCOデータセットを直接読み込むのでは無く、フォルダから画像を読み込むように作られているので、COCOデータセットのアノテーションデータからPNG画像を作成します。

$ python3 make_testdata.py

そうすると、「train_files/imgs」以下に256x256pxの写真データが、「train_files/mask」以下に256x256pxのマスク画像が生成されます。
独自のデータを学習させたい場合は、同じディレクトリ内に画像を配置すれば良いわけですね。

因みにこれで生成されるデータ数は、2818枚。
う〜ん、本格的な学習をするには画像の枚数が少ないなぁ・・・。
これは、COCOデータセットの中から猫の画像だけを取りだしているため。python3 make_testdata.py --category 'cat,dog'とかすれば、猫と犬の画像を抽出するので、データ数的には少しましになる。ただ、あまりクラス数が多くなると、認識が難しくなるので、悩みどころ。
本格的な用途で使うには、もっと大規模なデータセットを用意した方が良いでしょうね。

トレーニング

後はGitHubに書いてある手順そのまま、プログラムを叩くだけ。
バリデーション用のデータは、自動的にtrain_test_splitで0.01分が取り分けられる。

$ python3 train_testdata.py

学習には、HourgrassNetの論文に従って、OptimizerにRMSPropを使用している。Adamとかでも良かったかもしれないけど、Momentum SGDでは上手く動かなかった。
とりあえず20エポックほど学習させますが、実際はこのデータ数だと12エポック程度で収束する模様。手元で行った実験では160000枚のデータに対して22エポックでValidation lossが最小になりました。

結果

冒頭の画像の通り。にゃーん(笑)

いいわけ

ハッキリ言ってデータ数が足りてません。もっと、同じドメインに沢山のマスク画像があるデータセットで試してみたい所ですが、ぱっとは思いつかなかったので、とりあえず定番のCOCOデータセットを使ってみただけです。

まぁ、3000枚いかない枚数のデータで、しかも転移学習とか無しに直接学習させた結果なので、こんなもんでしょう。
より多数のデータを用意出来るなら、きっともっと良くなるはずです。

あと、この学習用PGはテスト用なので、Data Augmentationの手法は何も入っていません。単純に画像のRandom CropやFlip等入れるとか、Affine変換入れるとかでだいぶ変わるはずです。
Torchvisionのtransforms使っても良かったんだけど入力画像とマスク画像の両方に同じ変換入れなきゃなので面倒だったから省いたの・・・。

Reference

  1. XceptionHourgrass - PyTorch https://github.com/tanreinama/XceptionHourgrass---PyTorch
  2. Alejandro Newell, Kaiyu Yang, and Jia Deng "Stacked Hourglass Networks forHuman Pose Estimation" https://arxiv.org/pdf/1603.06937.pdf
  3. COCO Dataset http://cocodataset.org/#home