raspberry pi 1でtensorflow lite その7


概要

Raspberry Pi 1でTensorFlow Liteをやってみた。
tfliteファイルを作ってみた。
SavedModelから作ってみた。
データセットは、XOR。

環境

tensorflow 1.12

モデルを学習してSavedModelを作る。

import tensorflow as tf
import tensorflow.contrib.lite as lite
import numpy as np
from tensorflow.python.framework import graph_util

# Train a tiny 2-2-2 network on XOR and export it as a SavedModel in ./model.

# XOR truth table. Labels are one-hot: [1, 0] when XOR is 0, [0, 1] when XOR is 1.
X = [[0, 0], [0, 1], [1, 0], [1, 1]]
Y = [[1, 0], [0, 1], [0, 1], [1, 0]]

# The input placeholder and the softmax output are given explicit names
# ("input" / "output") so they can be identified in the exported graph.
x = tf.placeholder(tf.float32, shape=[None, 2], name="input")
y = tf.placeholder(tf.float32, shape=[None, 2])

# Sigmoid hidden layer followed by a softmax output layer.
w1 = tf.Variable(tf.random_uniform([2, 2], -1, 1, seed=0))
w2 = tf.Variable(tf.random_uniform([2, 2], -1, 1, seed=0))
b1 = tf.Variable(tf.zeros([2]))
b2 = tf.Variable(tf.zeros([2]))
h1 = tf.sigmoid(tf.matmul(x, w1) + b1)
h2 = tf.nn.softmax(tf.matmul(h1, w2) + b2, name="output")

# Cross-entropy between the one-hot labels and the softmax output,
# summed over the whole batch, minimized with plain gradient descent.
cost = -tf.reduce_sum(y * tf.log(h2))
opti = tf.train.GradientDescentOptimizer(0.1).minimize(cost)

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; this is its replacement.
    sess.run(tf.global_variables_initializer())

    # Full-batch training: all four XOR samples are fed on every step.
    for _ in range(10000):
        sess.run(opti, feed_dict={x: X, y: Y})

    # Show the trained network's prediction for each input pair.
    for sample in [[1, 1], [1, 0], [0, 1], [0, 0]]:
        print(sample, sess.run(h2, feed_dict={x: [sample]}))

    # Describe the model's interface: one input tensor and one output tensor.
    # The keys here are free-form ("input" / "output"), which matches the
    # generic predict method rather than the regress method.
    signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={"input": tf.saved_model.utils.build_tensor_info(x)},
        outputs={"output": tf.saved_model.utils.build_tensor_info(h2)},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
    )

    # Write the graph and the trained variable values to ./model under the
    # default serving signature key.
    builder = tf.saved_model.builder.SavedModelBuilder("./model")
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def,
        },
        assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
    )
    builder.save()

SavedModelからtfliteを作る。

import tensorflow as tf
import tensorflow.contrib.lite as lite

# Convert the SavedModel in ./model to a TensorFlow Lite flatbuffer and
# write it to xor4_model.tflite.
converter = lite.TFLiteConverter.from_saved_model("./model")
tflite_model = converter.convert()

# Use a context manager so the file is flushed and closed deterministically
# instead of relying on garbage collection of the anonymous file object.
with open("xor4_model.tflite", "wb") as f:
    f.write(tflite_model)

tfliteファイルを検証する。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite

# Load xor4_model.tflite into the TFLite interpreter and print its output
# for each of the four XOR input pairs.
interpreter = lite.Interpreter(model_path="xor4_model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)

# The first entry of each details list is used for the input and the output.
input_index = input_details[0]['index']
output_index = output_details[0]['index']

# Run one inference per sample; each sample is wrapped in an outer list so
# the array fed to the interpreter is 2-D (one row, two values).
for sample in ([0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
    input_data = np.array([sample], dtype=np.float32)
    print(input_data)
    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_index)
    print(output_data)

以上。