keras_to_tensorflow


概要

jsdoに、keras.jsで学習したグラフを使おうとしたが、だめだった。
だめなやつ
deeplearn.jsに、鞍替えするため、kerasのモデルをtensorflowに、変換した。
手順を記載する。

環境

windows 7 sp1 64bit
anaconda3
tensorflow 1.2

kerasでモデルを学習して、セーブする。

import numpy as np
from tensorflow.contrib.keras.python.keras.models import Model
from tensorflow.contrib.keras.python.keras.layers import Input, Dense, Activation
from tensorflow.contrib.keras.python.keras.optimizers import SGD
import matplotlib.pyplot as plt

x = np.arange(200).reshape(-1, 1) / 30
y = np.sin(x)

inputs = Input(shape = (1, ))
m = Dense(30)(inputs)
m = Activation('sigmoid')(m)
m = Dense(10)(m)
m = Activation('sigmoid')(m)
m = Dense(1)(m)
model = Model(inputs, m)
sgd = SGD(lr = 0.1)
model.compile(loss = 'mean_squared_error', optimizer = sgd)
model.summary();
model.fit(x, y, epochs = 400, batch_size = 10, verbose = 0)
preds = model.predict(x)
plt.plot(x, y, 'b', x, preds, 'r--')
plt.savefig("keras13.png")
plt.show()
with open('keras13_arch.json', 'w') as f:
    f.write(model.to_json())
model.save('keras13.h5')
print ("save model")

keras_to_tensorflowで、pbファイルに変換する。

pbファイルを、読んで、inputとoutputを確認。

import os
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile("keras13.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name = '')
    for op in tf.get_default_graph().get_operations():
        print (op.name)
        for output in op.outputs:
            print ('  ', output.name)


pbファイルを、読み込んで動作確認。

import os
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(200).reshape(-1, 1) / 30
y = np.sin(x)
g = []
with tf.Session() as sess:
    with gfile.FastGFile("keras13.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name = '')
    for i in range(200):
        result = sess.run('output_node0:0', feed_dict = {
            'input_1:0': [x[i]]
        })
        print (result)
        g.append(result)
    plt.plot(g)
    plt.savefig("predic13.png")
    plt.show()

以上。