PC側で変換後のTFLiteモデルをテストする方法
2238 ワード
前編ではtensorflowモデルをモバイル側に配置するためにtfliteモデルに変換する方法について説明します.
この分かち合いはどのようにPC端でtflite模型に対して予測を行って、模型が利用できるかどうかをテストします
まず、tfliteモデルをロードし、モデルの入出力を表示します.
印刷結果:
そこで私たちはここでinputのデータを処理する必要があります.便利のために、私はデータをcsvファイルに保存して、取り出して直接使用することができます.
これでinputデータを設定し、ネットワークモデルにデータを転送します.
予測の開始:
予測結果の読み出し
output_dataは予測結果のソースデータ、printが出てくるshapeは上output_detailのshape
この分かち合いはどのようにPC端でtflite模型に対して予測を行って、模型が利用できるかどうかをテストします
まず、tfliteモデルをロードし、モデルの入出力を表示します.
import numpy as np
import tensorflow as tf
import cv2
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="newModel.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
印刷結果:
[{'name': 'input_1', 'index': 115, 'shape': array([ 1, 224, 224, 3]), 'dtype': , 'quantization': (0.0, 0)}]
[{'name': 'activation_1/truediv', 'index': 6, 'shape': array([ 1, 12544, 2]), 'dtype': , 'quantization': (0.0, 0)}]
input details
から分かるように、入力するnumpy配列[1,224,224,3]、データ型はfloat 32である.index:115
の意味は、データが格納されている場所だと理解しています.そこで私たちはここでinputのデータを処理する必要があります.便利のために、私はデータをcsvファイルに保存して、取り出して直接使用することができます.
input_data = np.loadtxt('C:/Users\WIN10/input.csv',delimiter=',')
input_data = input_data.reshape(1,224,224,3)
input_data = input_data.astype(np.float32)
index = input_details[0]['index']
interpreter.set_tensor(index, input_data)
これでinputデータを設定し、ネットワークモデルにデータを転送します.
予測の開始:
interpreter.invoke()
予測結果の読み出し
output_data = interpreter.get_tensor(output_details[0]['index'])
print('output_data shape:',output_data.shape)
output_dataは予測結果のソースデータ、printが出てくるshapeは上output_detailのshape
[ 1, 12544, 2]
最後に、予測されたソースデータの後処理を行い、私たちが望む結果を解析する必要があります.ここでは、工事ごとに異なり、参考に供します.output_data = output_data.reshape(224,112)
pr = output_data.reshape(112,112,2).argmax( axis=2 )
seg_img = np.zeros( ( 112 , 112 , 3 ) )
seg_img[:,:,0] += ((pr[:,: ] == 1 )*200).astype('uint8')
seg_img[:,:,1] += ((pr[:,: ] == 1 )*200).astype('uint8')
seg_img[:,:,2] += ((pr[:,: ] == 1 )*200).astype('uint8')
cv2.imshow('img',seg_img)
cv2.waitKey(0)