【Flask+Keras】How to Run Fast Inference with Multiple Models on a Server


Conclusion

Right after `load_model()`, call `_make_predict_function()` on each model, capture the graph active at load time with `tf.get_default_graph()`, and wrap every `predict()` call in `with graph.as_default():`. `_make_predict_function()` builds the predict function eagerly, so the first request is not slow, and restoring the load-time graph keeps `predict()` working when Flask serves requests from worker threads (otherwise TensorFlow 1.x raises "Tensor is not an element of this graph").

Verified with:

keras==2.2.4
tensorflow==1.14.0
numpy==1.16.4
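
A quick way to confirm the installed versions match, run in the same Python environment:

import keras
import tensorflow as tf
import numpy as np

# prints e.g. "2.2.4 1.14.0 1.16.4"
print(keras.__version__, tf.__version__, np.__version__)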

Test code

from flask import Flask
import time

import numpy as np
import tensorflow as tf
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img

app = Flask(__name__)

model_path1 = "mnist.h5"
model1 = load_model(model_path1)
label1 = ["l0", "l1", "l2", "l3", "l4", "l5", "l6", "l7", "l8", "l9"]
model1._make_predict_function()  # <very important> pre-builds the predict function so predict() is fast
graph1 = tf.get_default_graph()  # remember the graph the model was loaded into


model_path2 = "mnist.h5"
model2 = load_model(model_path2)
label2 = ["l0", "l1", "l2", "l3", "l4", "l5", "l6", "l7", "l8", "l9"]
model2._make_predict_function()  # same treatment for the second model
graph2 = tf.get_default_graph()

def model1_predict(img_path):
    # load as 28x28 grayscale and scale pixel values to [0, 1]
    img = load_img(img_path, target_size=(28, 28), grayscale=True)
    img_nad = img_to_array(img) / 255
    img_nad = img_nad[None, ...]  # add a batch dimension: (28, 28, 1) -> (1, 28, 28, 1)
    global graph1
    with graph1.as_default():  # run predict() in the graph captured at load time
        pred = model1.predict(img_nad, batch_size=1, verbose=0)
    score = np.max(pred)
    pred_label = label1[np.argmax(pred[0])]
    print("score:", score, "label:", pred_label)

def model2_predict(img_path):
    # load as 28x28 grayscale and scale pixel values to [0, 1]
    img = load_img(img_path, target_size=(28, 28), grayscale=True)
    img_nad = img_to_array(img) / 255
    img_nad = img_nad[None, ...]  # add a batch dimension
    global graph2
    with graph2.as_default():  # run predict() in the graph captured at load time
        pred = model2.predict(img_nad, batch_size=1, verbose=0)
    score = np.max(pred)
    pred_label = label2[np.argmax(pred[0])]
    print("score:", score, "label:", pred_label)

@app.route("/", methods=['GET', 'POST'])
def webapp():
    start1 = time.time()
    model1_predict("mnist_test.jpg")
    end1 = time.time() - start1
    print("inference time <model1>:", end1, "sec")

    start2 = time.time()
    model2_predict("mnist_test.jpg")
    end2 = time.time() - start2
    print("inference time <model2>:", end2, "sec")

    output = "<p>model1: " + str(round(end1, 3)) + " sec</p><br><p>model2: " + str(round(end2, 3)) + " sec</p>"
    return output

if __name__ == "__main__":
    app.run(port=5000, debug=False)
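
To exercise the endpoint, assuming the server above is running locally and `mnist.h5` / `mnist_test.jpg` sit next to the script, a minimal client sketch:

import requests

# hit the endpoint a few times; the first call may carry one-time warm-up cost
for _ in range(3):
    print(requests.get("http://127.0.0.1:5000/").text)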

The important part

model1 = load_model(model_path1)
model1._make_predict_function()  # <very important> pre-builds the predict function so predict() is fast
graph1 = tf.get_default_graph()  # capture the graph active when the model was loaded

def model1_predict():
    global graph1
    with graph1.as_default():  # restore that graph before predicting
        pred = model1.predict(***, batch_size=1, verbose=0)
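
The same pattern generalizes to any number of models. As a minimal sketch (the `KerasPredictor` name is mine, not from the code above), each instance can carry its own graph reference:

import tensorflow as tf
from keras.models import load_model

class KerasPredictor:
    # Load a model once; pin every predict() to the graph it was loaded in.
    def __init__(self, model_path):
        self.model = load_model(model_path)
        self.model._make_predict_function()  # pre-build the predict function
        self.graph = tf.get_default_graph()  # graph active at load time

    def predict(self, x):
        with self.graph.as_default():  # restore that graph on the calling thread
            return self.model.predict(x, batch_size=1, verbose=0)

With this, `model1 = KerasPredictor("mnist.h5")` replaces the three module-level lines above.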