tensorflow|tensorflow実装ckptをpbファイルに転送

15007 ワード

このブログでは、自分で訓練して保存したckptモデルをpbファイルに変換することを実現します.この方法は、任意のckptモデルに適用されます.もちろん、ckptモデルの入出力ノード名を決定する必要があります.
目次
tensorflow実装ckptをpbファイルに転送
一、CKPTをPB形式に変換する
二、pbモデル予測
三、ソースコードのダウンロードと資料の推薦
1、トレーニング方法
2、当ブログGithubアドレス
3、モデルをAndroidに移植する方法
tfを使用する.train.saver()モデルを保存すると複数のファイルが生成され、計算図の構造と図上のパラメータの値を異なるファイルに分けて格納されます.この方法はTensorFlowで最も一般的な保存方法である.
たとえば、次のコードが実行されると、saveディレクトリに4つのファイルが保存されます.
import tensorflow as tf
#       
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() #        
saver = tf.train.Saver() #   tf.train.Saver       
with tf.Session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1)) #   v1、v2          
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt")  #       save/model.ckpt  
    print("Model saved in file:", saver_path)

ここで、checkpointはチェックポイントファイルであり、ファイルは1つのディレクトリの下のすべてのモデルファイルのリストを保存している.model.ckpt.MetaファイルはTensorFlow計算図の構造を保存する、ニューラルネットワークのネットワーク構造と理解でき、このファイルはtf.train.import_meta_graphは現在のデフォルトの図にロードされて使用されます.ckpt.Data:モデル内の各変数の値を保存しますが、多くの場合、TensorFlowのモデルを単一のファイル(モデル構造の定義と重みを含む)にエクスポートし、Androidにネットワークを導入するなど、他の場所での使用を容易にする必要があります.tfを利用する.train.write_graph()のデフォルトではネットワークの定義(重みなし)のみが導出、tfが利用される.train.Saver().save()エクスポートファイルgraph_defは重みから分離されるため,別の方法が必要である.知ってるよgraphdefファイルにはネットワーク内のVariable値は含まれていません(通常は重みが格納されています)が、constant値が含まれているため、Variableをconstantに変換すれば、ネットワークアーキテクチャと重みを1つのファイルで同時に格納する目標を達成できます.
TensoFlowはconvert_を提供してくれましたvariables_to_constants()メソッドでは、モデル構造を硬化させ、計算図の変数値を定数で保存し、保存したモデルをAndroidプラットフォームに移植することができます.
一、CKPTをPB形式に変換し、CKPTをPB形式のファイルに変換する過程は以下のように簡単に述べることができる.
CKPTモデルへのパスを入力してモデルの図と変数データをimport_meta_graphインポートモデルの図はsaver.restoreモデルから図中の各変数のデータをgraph_で復元するutil.convert_variables_to_constantsがモデルを永続化した次のCKPTをPB形式に変換した例は、私がGoogleNet InceptionV 3モデルに保存したckpt回転pbファイルを訓練した例で、訓練過程はブログを参考にすることができる:『自分のデータセットを使ってGoogLenet InceptionNet V 1 V 2 V 3モデル(TensorFlow)』:
 
def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB      
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #     ckpt        
    # input_checkpoint = checkpoint.model_checkpoint_path # ckpt    
 
    #          ,                 
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() #       
    input_graph_def = graph.as_graph_def()  #                
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #        
        output_graph_def = graph_util.convert_variables_to_constants(  #      ,      
            sess=sess,
            input_graph_def=input_graph_def,#   :sess.graph_def
            output_node_names=output_node_names.split(","))#          ,     
 
        with tf.gfile.GFile(output_graph, "wb") as f: #    
            f.write(output_graph_def.SerializeToString()) #     
        print("%d ops in the final graph." % len(output_graph_def.node)) #            
 
        # for op in graph.get_operations():
        #     print(op.name, op.values())

説明:
1、関数freeze_graphでは、「出力のノード名を指定する」ことが最も重要です.このノード名は元のモデルに存在するノードでなければなりません.freeze操作では、出力ノードの名前を定義する必要があります.ネットワークは実は複雑なので、出力ノードの名前を定義すると、freezeのときにそのノードを出力するために必要なサブマップを固化するだけで、他の関係のないものは捨てられます.私たちのfreezeモデルの目的は次に予測することです.だから、output_node_namesは一般的にネットワークモデルの最後のレイヤが出力するノード名,あるいは我々が予測する目標である.
2、保存するときはconvert_variables_to_constants関数は、硬化するノード名を指定します.私のコードでは、硬化するノードは1つだけです.output_node_names.ノード名とテンソルの名前の違いに注意してください.たとえば、「input:0」はテンソルの名前であり、「input」はノードの名前を表します.
3、ソースコードにgraph=tfを通す.get_default_graph()は、saver=tfであるデフォルトの図を得る.train.import_meta_graph(input_checkpoint+'.meta',clear_devices=True)が復元する図なので、tfを先に実行する必要がある.train.import_meta_graph、tfを実行する.get_default_graph() .
4、実質的には、リカバリされたセッションsessで、デフォルトのネットワーク図を直接取得することができます.より簡単な方法は、以下の通りです.
def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB      
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #     ckpt        
    # input_checkpoint = checkpoint.model_checkpoint_path # ckpt    
 
    #          ,                 
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #        
        output_graph_def = graph_util.convert_variables_to_constants(  #      ,      
            sess=sess,
            input_graph_def=sess.graph_def,#   :sess.graph_def
            output_node_names=output_node_names.split(","))#          ,     
 
        with tf.gfile.GFile(output_graph, "wb") as f: #    
            f.write(output_graph_def.SerializeToString()) #     
        print("%d ops in the final graph." % len(output_graph_def.node)) #            

呼び出し方法は簡単です.ckptモデルのパスを入力し、pbモデルのパスを出力すればいいです.
#ckptモデルパスinput_を入力checkpoint='models/model.ckpt-1000'#出力pbモデルのパスout_pb_path="models/pb/frozen_model.pb"#freeze_を呼び出すgraphはckptをpb freeze_に変換graph(input_checkpoint,out_pb_path)5、上記および説明:保存時にconvert_variables_to_constants関数は、硬化するノード名を指定します.私のコードでは、硬化するノードは1つだけです.output_node_names.そのため、他のネットワークモデルでも、出力されたノード名output_を簡単に変更することができます.node_names、ckptをpbファイルに変換します.
PS:ノード名に注意し、name_を含めるscopeとvariable_scopeネーミングスペースは、「InceptionV 3/Logits/SpatialSqueeze」などの「/」で区切られています.
二、pbモデル予測以下はpbモデルを予測するコードである
 
def freeze_graph_test(pb_path, image_path):
    '''
    :param pb_path:pb     
    :param image_path:       
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
 
            #          ,           
            # input:0      ,keep_prob:0  dropout   ,     1,is_training:0    
            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
            #          
            output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
            #       
            im=read_image(image_path,resize_height,resize_width,normalization=True)
            im=im[np.newaxis,:]
            #             ,                tensor   ,         
            # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
            out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                                        input_keep_prob_tensor:1.0,
                                                        input_is_training_tensor:False})
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print "pre class_id:{}".format(sess.run(class_id))

説明:
1、ckpt予測とは異なり、pbファイルはネットワークモデル構造を固化しているため、元のトレーニングモデル(train)のソースコードが分からなくても、ネットワーク図を復元し、予測することができる.復元モデルは、読み込まれたシーケンス化データからネットワーク構造をインポートするだけで簡単です.
tf.import_graph_def(output_graph_def,name=")2ですが、元のネットワークモデルの入出力のノード名を知る必要があります(もちろん、データを渡すときは、入出力のテンソルを入力することで完了します).InceptionV 3モデルの入力には3つのノードがあるため、ここでは、ネットワーク構造の入力テンソルに対応する入力テンソル名を定義する必要があります.
input_image_tensor = sess.graph.get_tensor_by_name("input:0") input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0") input_is_training_tensor = sess.graph.get_tensor_by_name(「is_training:0」)および出力のテンソル名:
output_tensor_name = sess.graph.get_tensor_by_name(「InceptionV 3/Logits/SpatialSqueeze:0」)3、予測の場合、feed入力データが必要:
#読み込まれたモデルが正しいかどうかをテストします.ここで入力ノードのtensorの名前が出力ノードと入力ノードの名前であり、操作ノードの名前ではありません.#out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False}) out=sess.run(output_tensor_name,feed_dict={input_image_tensor:im,input_keep_prob_tensor:1.0,input_is_training_tensor:False})4、その他のネットワークモデル予測の場合、入出力するテンソルの名前を変更することもできます.
PS:注意テンソルの名称、即ち:ノード名+“:”+“id号”、例えば“InceptionV 3/Logits/SpatialSqueeze:0”
完全なCKPTをPB形式と予測に変換するコードは以下の通りである.
# -*-coding: utf-8 -*-
"""
    @Project: tensorflow_models_nets
    @File   : convert_pb.py
    @Author : panjq
    @E-mail : [email protected]
    @Date   : 2018-08-29 17:46:50
    @info   :
    -     CKPT                 
    -   import_meta_graph        
    -   saver.restore                
    -   graph_util.convert_variables_to_constants       
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299  #       
resize_width = 299  #       
depths = 3
 
def freeze_graph_test(pb_path, image_path):
    '''
    :param pb_path:pb     
    :param image_path:       
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
 
            #          ,           
            # input:0      ,keep_prob:0  dropout   ,     1,is_training:0    
            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
            #          
            output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
            #       
            im=read_image(image_path,resize_height,resize_width,normalization=True)
            im=im[np.newaxis,:]
            #             ,                tensor   ,         
            # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
            out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                                        input_keep_prob_tensor:1.0,
                                                        input_is_training_tensor:False})
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print "pre class_id:{}".format(sess.run(class_id))
 
 
def freeze_graph(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB      
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #     ckpt        
    # input_checkpoint = checkpoint.model_checkpoint_path # ckpt    
 
    #          ,                 
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #        
        output_graph_def = graph_util.convert_variables_to_constants(  #      ,      
            sess=sess,
            input_graph_def=sess.graph_def,#   :sess.graph_def
            output_node_names=output_node_names.split(","))#          ,     
 
        with tf.gfile.GFile(output_graph, "wb") as f: #    
            f.write(output_graph_def.SerializeToString()) #     
        print("%d ops in the final graph." % len(output_graph_def.node)) #            
 
        # for op in sess.graph.get_operations():
        #     print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB      
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #     ckpt        
    # input_checkpoint = checkpoint.model_checkpoint_path # ckpt    
 
    #          ,                 
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph() #       
    input_graph_def = graph.as_graph_def()  #                
 
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #        
        output_graph_def = graph_util.convert_variables_to_constants(  #      ,      
            sess=sess,
            input_graph_def=input_graph_def,#   :sess.graph_def
            output_node_names=output_node_names.split(","))#          ,     
 
        with tf.gfile.GFile(output_graph, "wb") as f: #    
            f.write(output_graph_def.SerializeToString()) #     
        print("%d ops in the final graph." % len(output_graph_def.node)) #            
 
        # for op in graph.get_operations():
        #     print(op.name, op.values())
 
 
if __name__ == '__main__':
    #   ckpt    
    input_checkpoint='models/model.ckpt-10000'
    #   pb     
    out_pb_path="models/pb/frozen_model.pb"
    #   freeze_graph ckpt  pb
    freeze_graph(input_checkpoint,out_pb_path)
 
    #   pb  
    image_path = 'test_image/animal.jpg'
    freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

三、ソースコードのダウンロードと資料の推薦1、訓練方法の上のCKPTをPB形式に変換した例は、私がGoogleNet InceptionV 3モデルに保存したckptを訓練してpbファイルを転送した例で、訓練過程はブログを参考にすることができる.
『自分のデータセットを使ってGoogLenet InceptionNet V 1 V 2 V 3モデル(TensorFlow)』:https://blog.csdn.net/guyuealian/article/details/81560537
2、GithubアドレスGithubソース:https://github.com/PanJinquan/tensorflow_models_netsのconvert_pb.pyファイル
事前トレーニングモデルのダウンロードアドレス:https://download.csdn.net/download/guyuealian/10610847
3、モデルをAndroidに移植する方法pbファイルはAndroidプラットフォームに移植して実行することができ、その方法は参考にすることができる.
「tensorflowで訓練したモデルをAndroid(MNIST手書きデジタル識別)に移植する」
参照先:
[1]https://blog.csdn.net/guyuealian/article/details/79672257
【2】https://blog.csdn.net/yjl9122/article/details/78341689