tensorflowで訓練したモデルをandroidに移植する
5779 ワード
tensorflowで訓練したモデルをandroidに移植する
説明
本稿では、訓練されたモデルをandroidデバイスに移植し、androidデバイスに処理対象データを入力し、モデルを介して出力データを取得する方法について説明する.1つの例を通して,移植全体の過程を述べる.(demoのソースコードはgithubにアクセスしましたhttps://github.com/CrystalChen1017/TSFOnAndroid)全体の考え方は以下の通りである. pythonを使用してPC上でモデルを訓練し、pbファイル に保存します. androidプロジェクトを新規作成し、pbファイルをassetsフォルダの下に tensorflowのsoファイルおよびjarパッケージをlibsの下に 配置するライブラリファイルをロードし、tensorflowをappで実行する の準備を tensorflowの環境、参照http://blog.csdn.net/cxq234843654/article/details/70857562 libtensorflow_inference.so libandroid_tensorflow_inference_java.jar 上記の2つのファイルを自分でコンパイルするにはbazelをインストールする必要があります.参照http://blog.csdn.net/cxq234843654/article/details/70861155の2ステップ 以上の2つのファイルは、次の2つのWebサイトからダウンロードされます.https://github.com/CrystalChen1017/TSFOnAndroid/tree/master/app/libsまたはhttp://download.csdn.net/detail/cxq234843654/9833372
PC側モデルの準備
これは簡単なモデルで,入力は配列matrix 1であり,操作後,この配列に2*matrix 1を乗じたものである.は入力データに を付与する必要がある.は、送信データに を取得する必要がある.はtfを使用することができない.train.write_graph()は、モデルの構造を保存するだけであり、訓練済みのパラメータ値 を保存しないため、モデルを保存する.はtfを使用することができない.train.saver()は、ネットワーク内のパラメータ値を保存するだけで、モデルの構造を保存しないため、モデルを保存します. を指定する. を指定する. に書き込む.
実行するとmodelフォルダの下にcxqが生成されます.pbファイル、今このファイルはさっきの一連の操作を固化したので、次回変数を計算して2を乗じる必要がある場合、私たちは直接pbファイルを手に入れて、入力を指定して、出力を取得することができます.
(オプション)bazelはsoとjarファイルをコンパイルします
自分でtensorflowのソースコードでsoとjarファイルをコンパイルしたい場合は、端末を介してtensorflowのディレクトリの下に入り、以下の操作を行う必要があります.コンパイルsoライブラリ
コンパイル完了後、libtensorflow_inference.soの経路は:/tensorflow/bazel-bin/tensorflow/contrib/android jarパッケージのコンパイル
コンパイル完了後、android_tensorflow_inference_java.JArのパスは:/tensorflow/bazel-bin/tensorflow/contrib/android
Android側の準備 Androidプロジェクト を新規作成さっきのpbファイルをassetsフォルダの下に に保存します. libandroid_tensorflow_inference_java.JArは/app/libsディレクトリに格納され、右クリック「add as Libary」 /app/libsの下にarmeabiフォルダを新規作成し、libtensorflow_inference.so app:gradleおよびgradleを構成する.properties androidノードの下にsoureSetsを追加し、jniLibsのパス を作成する. defaultConfigノードの下に を追加はgradle.propertiesに次の行 を追加
以上の3ステップの操作によりtensorflowの環境が導入されました.
モデルの呼び出し
まずMyTSFクラスを新規作成し、このクラスでモデルの呼び出しを行い、出力を取得します.
ActivityでのMyTSFクラスの使用
説明
本稿では、訓練されたモデルをandroidデバイスに移植し、androidデバイスに処理対象データを入力し、モデルを介して出力データを取得する方法について説明する.1つの例を通して,移植全体の過程を述べる.(demoのソースコードはgithubにアクセスしましたhttps://github.com/CrystalChen1017/TSFOnAndroid)全体の考え方は以下の通りである.
PC側モデルの準備
これは簡単なモデルで,入力は配列matrix 1であり,操作後,この配列に2*matrix 1を乗じたものである.
input
と命名する、android側ではこのinput
を用いて入力データにoutput
と命名する、android側では、このoutput
を用いて出力値graph_util.convert_variables_to_constants
は、sesion全体を定数として保存することができ、output_node_names
パラメータにより出力tf.gfile.FastGFile('model/cxq.pb', mode='wb')
保存ファイルの経路及び読み書き方式f.write(output_graph_def.SerializeToString())
固化モデルをファイル# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.python.client import graph_util
session = tf.Session()
matrix1 = tf.constant([[3., 3.]], name='input')
add2Mat = tf.add(matrix1, matrix1, name='output')
session.run(add2Mat)
output_graph_def = graph_util.convert_variables_to_constants(session, session.graph_def,output_node_names=['output'])
with tf.gfile.FastGFile('model/cxq.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
session.close()
実行するとmodelフォルダの下にcxqが生成されます.pbファイル、今このファイルはさっきの一連の操作を固化したので、次回変数を計算して2を乗じる必要がある場合、私たちは直接pbファイルを手に入れて、入力を指定して、出力を取得することができます.
(オプション)bazelはsoとjarファイルをコンパイルします
自分でtensorflowのソースコードでsoとjarファイルをコンパイルしたい場合は、端末を介してtensorflowのディレクトリの下に入り、以下の操作を行う必要があります.
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
-- crosstool_top=//external:android/crosstool \
-- host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
-- cpu=armeabi-v7a
コンパイル完了後、libtensorflow_inference.soの経路は:/tensorflow/bazel-bin/tensorflow/contrib/android
bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
コンパイル完了後、android_tensorflow_inference_java.JArのパスは:/tensorflow/bazel-bin/tensorflow/contrib/android
Android側の準備
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
defaultConfig {
ndk {
abiFilters "armeabi"
}
}
android.useDeprecatedNdk=true
以上の3ステップの操作によりtensorflowの環境が導入されました.
モデルの呼び出し
まずMyTSFクラスを新規作成し、このクラスでモデルの呼び出しを行い、出力を取得します.
package com.learn.tsfonandroid;
import android.content.res.AssetManager;
import android.os.Trace;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
public class MyTSF {
private static final String MODEL_FILE = "file:///android_asset/cxq.pb"; //
//
private static final int HEIGHT = 1;
private static final int WIDTH = 2;
//
private static final String inputName = "input";
//
private float[] inputs = new float[HEIGHT * WIDTH];
//
private static final String outputName = "output";
//
private float[] outputs = new float[HEIGHT * WIDTH];
TensorFlowInferenceInterface inferenceInterface;
static {
//
System.loadLibrary("tensorflow_inference");
}
MyTSF(AssetManager assetManager) {
//
inferenceInterface = new TensorFlowInferenceInterface(assetManager,MODEL_FILE);
}
public float[] getAddResult() {
//
inputs[0]=1;
inputs[1]=3;
// feed tensorflow
Trace.beginSection("feed");
inferenceInterface.feed(inputName, inputs, WIDTH, HEIGHT);
Trace.endSection();
// 2
Trace.beginSection("run");
String[] outputNames = new String[] {outputName};
inferenceInterface.run(outputNames);
Trace.endSection();
// outputs
Trace.beginSection("fetch");
inferenceInterface.fetch(outputName, outputs);
Trace.endSection();
return outputs;
}
}
ActivityでのMyTSFクラスの使用
public void click01(View v){
Log.i(TAG, "click01: ");
MyTSF mytsf=new MyTSF(getAssets());
float[] result=mytsf.getAddResult();
for (int i=0;i