機械学習で友人をアニメキャラに例えてみた


この記事について

機械学習を使って、僕の友人の僧侶インフルエンサー稲田ズイキ君をアニメキャラクターに例えてみました。
背景についてはこちらの記事をご覧ください。

ユーザーの入力画像に対してデータセット内の画像で似ている画像を返すアプリを作ることを目標とします。
キーワードはTuricreate, Flask, Docker。こんな感じです。

データ収集

なんと言ってもまずはデータ。
顔デカキャラの名前と画像の枚数を指定するだけで画像を集めてくれるスクリプトを書きます。

PythonのBeatiful Soupというライブラリを使ってスクレイパーを作成しました。

# Import libralies
import bs4
from sys import argv
import urllib.request, urllib.error
import os
import argparse
import sys
import json


def get_soup(url, header):
    # Use BeautifulSoup for scraper
    return bs4.BeautifulSoup(urllib.request.urlopen(urllib.request.Request(url, headers=header)), 'html.parser')


def main(args):
    # Arg parser
    # -k: Keyword for image search. Several keywords can work like; banana+monkey
    # -n: Number of image files to download
    # -o: Output directory
    parser = argparse.ArgumentParser(description='Options for scraping Google images')
    parser.add_argument('-k', '--search', default='banana', type=str, help='Search keywords')
    parser.add_argument('-n', '--num_images', default=200, type=int, help='Number of images to scrape')
    parser.add_argument('-o', '--directory', default='./', type=str, help='Output directory')
    args = parser.parse_args()

    # Put several keywords together
    query = args.search.split()
    query = '+'.join(query)

    # Make output folder if it does not exist
    save_directory = args.directory + '/' + query
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # Define scraper
    url = "https://www.google.com/search?q=" + urllib.parse.quote_plus(query, encoding='utf-8') + "&source=lnms&tbm=isch"
    header = {'User-Agent': "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_3)"
                            "AppleWebKit/537.36 (KHTML, like Gecko)"
                            "Chrome/43.0.2357.134 Safari/537.36"}
    soup = get_soup(url, header)

    image_data_list = []
    # Get search result, (image url and extension)
    for a in soup.find_all("div", {"class": "rg_meta"}):
        link, img_type = json.loads(a.text)["ou"], json.loads(a.text)["ity"]
        image_data_list.append((link, img_type))

    # Save each image file
    for index, (img, img_type) in enumerate(image_data_list):
        try:
            img_type = img_type if len(img_type) > 0 else 'jpg'
            print("Downloading image {} ({}), type is {}".format(index, img, img_type))
            raw_img = urllib.request.urlopen(img).read()
            f = open(os.path.join(save_directory, "img_"+str(index)+"."+img_type), 'wb')
            f.write(raw_img)
            f.close()
        except Exception as e:
            print("could not load : "+img)
            print(e)


if __name__ == "__main__":
    try:
        main(argv)
    except KeyboardInterrupt:
        pass
    sys.exit()

上記のコードをimage_collector.pyとして保存、実行します。
以下の例ではニコちゃん大王の画像を100枚集めています。

python image_collector -k ニコちゃん大王 -o ./data -n 100

無心で顔デカキャラの画像を集めていきます。(あくまで個人的な研究目的にのみ利用)

類似度計算

TuricreateというAppleが公開しているOSSで簡単に類似度計算ができると聞いたのでそちらを使っていきます。

ResNet-50をベースとした転移学習を行なっているようで、画像のリサイズ等の前処理含めフォルダを指定するだけでよしなにしてくれます。流れとしては以下の感じです。

  1. trainディレクトリにある画像をもとにモデルを作成、保存
  2. モデルの読み込み
  3. trainディレクトリ内の画像に対して入力画像の類似度を計算、近い順にソート

また、kを指定することでユークリッド距離を基準にソートされた類似画像を何枚まで返すかを指定できます。

以下のコードの一部を載せておきます。
本記事での開発に使ったコードはGitHubに置いておきます。

  • 著作権の都合で度々コードに登場するdatasetに関わるディレクトリがないことに注意してください。
  • image_file_pathはアプリ等でユーザーからの入力画像を想定しています。
import turicreate as tc

# Save training data
train_data = tc.image_analysis.load_images('train')
train_data = train_data.add_row_number()
train_data.save(os.path.join('train_data', 'train.sframe'))

# Train model with Resnet pre-trained model
trained_model = tc.image_similarity.create(train_data)

# Save temporary image (tmp_data内に入力画像が保存されている前提)
tmp_data = tc.image_analysis.load_images('tmp_data')
tmp_data = tmp_data.add_row_number()
tmp_data.save(os.path.join('tmp_data', 'tmp.sframe'))

# Get image file ID from image file path
image_id = tmp_data[tmp_data['path'] == os.path.join('tmp_data', image_file_path)]['id'][0]

# Query with test data (k is number of the nearest neighbors to return)
query_results = model.query(tmp_data[image_id]['image'], k=10)

# Visualize result
tmp_data[image_id]['image'].show()  # Leave it for debug

# Show result
similar_image_ids = query_results[query_results['query_label'] == 0]['reference_label']
similar_image_top_ten = train_data.filter_by(similar_image_ids, 'id')['path']
train_data.filter_by(similar_image_ids, 'id').explore()

テスト

Turicreateには.explore()という関数が実装されており、類似度計算の実行結果をこのようにGUIで表示させてくれます。

上から順に類似度が高い画像になります。
もちろん実際のユークリッド距離の値もデバッガー等で確認することもできます。

Flaskアプリ化

ここからはFlaskを使ったアプリ化についてです。

  1. ユーザーが入力画像を選択、送信
  2. バックエンドでTuricreateを使って推論、類似度計算
  3. もっとも類似度の高い画像を返し、入力画像とともに表示

こんな感じのアプリを作っていきます。

# Import libraries
import os
import base64
import shutil
from PIL import Image
from io import BytesIO
from werkzeug import secure_filename
from flask import Flask, render_template, request, redirect, url_for
from src.image_similarity_evaluation import ImageSimilarityEvaluation

# Init Flask app
app = Flask(__name__)
app.config.update(
    STATIC_FOLDER='src/static',
    APPLICATION_ROOT='../../',
    UPLOAD_FOLDER='dataset/tmp',
    REFERENCE_FOLDER='dataset/train',
    TARGET_FILE='',
    ALLOWED_EXTENSIONS=set(['png', 'jpg', 'jpeg', 'PNG', 'JPG']),
    SECRET_KEY=os.urandom(24),
    BOOTSTRAP_SERVE_LOCAL=True,
    # Flask-Dropzone config:
    DROPZONE_ALLOWED_FILE_TYPE='image',
    DROPZONE_MAX_FILE_SIZE=3,
    DROPZONE_MAX_FILES=20,
    DROPZONE_UPLOAD_ON_CLICK=True
)

# Init for Image Similarity Evaluation
app.ISE = ImageSimilarityEvaluation(setting_file='config/master_config.ini')
model_file_path = 'model/inada.model'


def routine():
    # Remove folder and recreate it
    shutil.rmtree(app.config['UPLOAD_FOLDER'])
    os.mkdir(app.config['UPLOAD_FOLDER'])
    print("Image files removed")


def allowed_file(filename):
    return '.' in filename and \
           filename.rsplit('.', 1)[1] in app.config['ALLOWED_EXTENSIONS']


@app.route('/')
def index():
    # routine()
    return render_template('index.html')


@app.route('/send', methods=['GET', 'POST'])
def send():
    # Get image data if it uses POST method
    if request.method == 'GET' or 'POST':
        img_file = request.files['img_file']

        # Get image file path if file extension is good
        if img_file and allowed_file(img_file.filename):
            filename = secure_filename(img_file.filename)
            app.config['TARGET_FILE'] = secure_filename(filename)
        else:
            return ''' <p>This file format is not currently supported</p> '''

        # Save image file into tmp directory
        if not os.path.isdir(app.config['UPLOAD_FOLDER']):
            os.mkdir(app.config['UPLOAD_FOLDER'])

        # Save target image file into tmp folder
        original_file_path = os.path.join(app.config['UPLOAD_FOLDER'], app.config['TARGET_FILE'])
        img_file.save(original_file_path)

        # Convert image data to binary
        image = Image.open(img_file)
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        input_image_data = base64.b64encode(buffer.getvalue()).decode().replace("'", "")

        # Run evaluation script
        similar_image_paths = app.ISE.evaluation(image_file_path=app.config['TARGET_FILE'], model_file_path=model_file_path)
        similar_image_path = similar_image_paths[0]
        similar_image_path = similar_image_path.replace(app.config['STATIC_FOLDER']+'/', '')

        # Remove temp image data
        os.remove(os.path.join(app.config['UPLOAD_FOLDER'], app.config['TARGET_FILE']))
        return render_template('result.html', input_image_data=input_image_data, similar_image_path=similar_image_path)

    else:
        return redirect(url_for('index'))


@app.route('/retry', methods=['GET', 'POST'])
def retry():
    return redirect(url_for('index'))


if __name__ == '__main__':
    # Run app
    app.jinja_env.cache = {}
    app.run(debug=True, host='0.0.0.0', port=os.environ.get('PORT'))

入力された画像をバイナリ形式で保存、その画像データをhtmlに渡します。
また、入力画像を対象として先ほどの類似度計算を関数化したものを実行し、もっとも類似度の高い画像のファイルパスをhtmlに渡します。

html, css等のコードもGitHubに載せておきました。

完成したアプリはこのように見えます。

Docker化

最後にDockerを使ってFlaskアプリを簡単に起動できるようにします。
Flaskで利用するポート番号を5000にすることと、staticファイル達へのパスを指定するのに注意します。

FROM ubuntu:latest

RUN apt-get update
RUN apt-get install python3 python3-pip -y

# COPY stuff, directory path setting
ADD . /app
WORKDIR /app
ENV PYTHONPATH /app
ENV STATIC_URL src/static
ENV STATIC_PATH src/static

# Download libraries
RUN pip3 install -r src/requirements.txt

# Port forward
EXPOSE 5000

# Run web app
CMD python3 src/app.py

あとはDockerコマンドを使ってアプリを作成、起動します。

docker build -t charagao .
docker run -p 5000:5000 charagao

さて、気になる稲田ズイキ君に似ているキャラは。。。

おわりに

Turicreate, Flask, Dockerを使って画像類似度を計算するアプリを作ってみました。
今回は遊びで顔デカキャラを選びましたが、類似画像検索アプリの開発等にも応用できそうです。

他にも面白いお題があったら開発してみようと思うので、ぜひコメントやメッセージください。

今回協力してくれた稲田ズイキ君には全ての許可を取ってあります。今度焼肉でもおごります。

ここまで読んでくださってありがとうございました。

こっそりTwitterやってます。

Aki

※この記事は稲田くんの許可をとって制作しました。
見せてみたら、稲田くんは「美味しいな〜〜」と言って、とっても喜んでいました。
数少ない自慢の親友です。