tensorflowで訓練したモデルをandroidに移植する

5779 ワード

tensorflowで訓練したモデルをandroidに移植する
説明
本稿では、訓練されたモデルをandroidデバイスに移植し、androidデバイスに処理対象データを入力し、モデルを介して出力データを取得する方法について説明する.1つの例を通して,移植全体の過程を述べる.(demoのソースコードはgithubにアクセスしましたhttps://github.com/CrystalChen1017/TSFOnAndroid)全体の考え方は以下の通りである.
  • pythonを使用してPC上でモデルを訓練し、pbファイル
  • に保存します.
  • androidプロジェクトを新規作成し、pbファイルをassetsフォルダの下に
  • tensorflowのsoファイルおよびjarパッケージをlibsの下に
  • 配置する
  • ライブラリファイルをロードし、tensorflowをappで実行する
  • の準備を
  • tensorflowの環境、参照http://blog.csdn.net/cxq234843654/article/details/70857562
  • libtensorflow_inference.so
  • libandroid_tensorflow_inference_java.jar
  • 上記の2つのファイルを自分でコンパイルするにはbazelをインストールする必要があります.参照http://blog.csdn.net/cxq234843654/article/details/70861155の2ステップ
  • 以上の2つのファイルは、次の2つのWebサイトからダウンロードされます.https://github.com/CrystalChen1017/TSFOnAndroid/tree/master/app/libsまたはhttp://download.csdn.net/detail/cxq234843654/9833372
    PC側モデルの準備
    これは簡単なモデルで,入力は配列matrix 1であり,操作後,この配列に2*matrix 1を乗じたものである.
  • は入力データにinputと命名する、android側ではこのinputを用いて入力データに
  • を付与する必要がある.
  • は、送信データにoutputと命名する、android側では、このoutputを用いて出力値
  • を取得する必要がある.
  • はtfを使用することができない.train.write_graph()は、モデルの構造を保存するだけであり、訓練済みのパラメータ値
  • を保存しないため、モデルを保存する.
  • はtfを使用することができない.train.saver()は、ネットワーク内のパラメータ値を保存するだけで、モデルの構造を保存しないため、モデルを保存します.
  • graph_util.convert_variables_to_constantsは、sesion全体を定数として保存することができ、output_node_namesパラメータにより出力
  • を指定する.
  • tf.gfile.FastGFile('model/cxq.pb', mode='wb')保存ファイルの経路及び読み書き方式
  • を指定する.
  • f.write(output_graph_def.SerializeToString())固化モデルをファイル
  • に書き込む.
    # -*- coding:utf-8 -*-
    import tensorflow as tf
    from tensorflow.python.client import graph_util
    
    session = tf.Session()
    
    matrix1 = tf.constant([[3., 3.]], name='input')
    add2Mat = tf.add(matrix1, matrix1, name='output')
    
    session.run(add2Mat)
    
    output_graph_def = graph_util.convert_variables_to_constants(session, session.graph_def,output_node_names=['output'])
    
    with tf.gfile.FastGFile('model/cxq.pb', mode='wb') as f:
        f.write(output_graph_def.SerializeToString())
    
    session.close()
    
    

    実行するとmodelフォルダの下にcxqが生成されます.pbファイル、今このファイルはさっきの一連の操作を固化したので、次回変数を計算して2を乗じる必要がある場合、私たちは直接pbファイルを手に入れて、入力を指定して、出力を取得することができます.
    (オプション)bazelはsoとjarファイルをコンパイルします
    自分でtensorflowのソースコードでsoとjarファイルをコンパイルしたい場合は、端末を介してtensorflowのディレクトリの下に入り、以下の操作を行う必要があります.
  • コンパイルsoライブラリ
  • bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
        -- crosstool_top=//external:android/crosstool \
        -- host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
        -- cpu=armeabi-v7a
    

    コンパイル完了後、libtensorflow_inference.soの経路は:/tensorflow/bazel-bin/tensorflow/contrib/android
  • jarパッケージのコンパイル
  • bazel build //tensorflow/contrib/android:android_tensorflow_inference_java
    
    

    コンパイル完了後、android_tensorflow_inference_java.JArのパスは:/tensorflow/bazel-bin/tensorflow/contrib/android
    Android側の準備
  • Androidプロジェクト
  • を新規作成
  • さっきのpbファイルをassetsフォルダの下に
  • に保存します.
  • libandroid_tensorflow_inference_java.JArは/app/libsディレクトリに格納され、右クリック「add as Libary」
  • /app/libsの下にarmeabiフォルダを新規作成し、libtensorflow_inference.so
  • app:gradleおよびgradleを構成する.properties
  • androidノードの下にsoureSetsを追加し、jniLibsのパス
  • を作成する.
    sourceSets {
            main {
                jniLibs.srcDirs = ['libs']
            }
        }
    
  • defaultConfigノードの下に
  • を追加
    defaultConfig {
    
            ndk {
                abiFilters "armeabi"
            }
        }
    
  • はgradle.propertiesに次の行
  • を追加
    android.useDeprecatedNdk=true
    

    以上の3ステップの操作によりtensorflowの環境が導入されました.
    モデルの呼び出し
    まずMyTSFクラスを新規作成し、このクラスでモデルの呼び出しを行い、出力を取得します.
    package com.learn.tsfonandroid;
    
    import android.content.res.AssetManager;
    import android.os.Trace;
    
    import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
    
    
    public class MyTSF {
        private static final String MODEL_FILE = "file:///android_asset/cxq.pb"; //      
    
        //     
        private static final int HEIGHT = 1;
        private static final int WIDTH = 2;
    
        //          
        private static final String inputName = "input";
        //           
        private float[] inputs = new float[HEIGHT * WIDTH];
    
        //          
        private static final String outputName = "output";
        //           
        private float[] outputs = new float[HEIGHT * WIDTH];
    
    
    
        TensorFlowInferenceInterface inferenceInterface;
    
    
        static {
            //     
            System.loadLibrary("tensorflow_inference");
        }
    
        MyTSF(AssetManager assetManager) {
            //    
            inferenceInterface = new TensorFlowInferenceInterface(assetManager,MODEL_FILE);
        }
    
        public float[] getAddResult() {
            //       
            inputs[0]=1;
            inputs[1]=3;
    
            //   feed tensorflow
            Trace.beginSection("feed");
            inferenceInterface.feed(inputName, inputs, WIDTH, HEIGHT);
            Trace.endSection();
    
            //   2   
            Trace.beginSection("run");
            String[] outputNames = new String[] {outputName};
            inferenceInterface.run(outputNames);
            Trace.endSection();
    
            //      outputs 
            Trace.beginSection("fetch");
            inferenceInterface.fetch(outputName, outputs);
            Trace.endSection();
    
            return outputs;
        }
    
    
    }
    
    
    

    ActivityでのMyTSFクラスの使用
     public void click01(View v){
            Log.i(TAG, "click01: ");
            MyTSF mytsf=new MyTSF(getAssets());
            float[] result=mytsf.getAddResult();
            for (int i=0;i