TensorFlowはPBファイルのモデルを保存します.

8231 ワード

通常、TensorFlowを使用する時、モデルを保存する時に、ckpt形式のモデルファイルを使って、似たような語句を使ってモデルを保存します.
tf.train.Saver().save(sess,ckpt_file_path,max_to_keep=4,keep_checkpoint_every_n_hours=2) 
以下のステートメントを使用して、すべての変数情報を復元します.
saver.restore(sess,tf.train.latest_checkpoint('./ckpt'))  
しかし、この方法にはいくつかの欠点があります.まず、このモデルファイルはTensorFlowに依存しています.その枠組みの中でしか使えません.第二に、モデルを回復する前に、ネットワーク構造をもう一度定義してから、変数の値をネットワークに回復する必要があります.
Googleが推奨する保存モデルは、PBファイルのモデルを保存し、独立して動作し、クローズドされたプログレッシブ形式で、他の言語と深さ学習フレームの読み取り、訓練を継続し、TensorFlowのモデルを移転することができます.
その主な使用シーンは,前方導出inferenceのコードが統一されるように,モデルの作成と使用のデカップリングを実現することである.
また、PBファイルに保存するとモデルの変数が固定になり、モデルのサイズが大幅に小さくなり、携帯端末での動作に適しているという利点があります.
具体的な詳細は、このようなPBファイルはMetaGraphhのprotocol bufferフォーマットを示すファイルであり、MetaGraphhは計算図、データストリーム、および関連する変数と入出力signature、およびastertsは計算図を作成する際に追加のファイルを指す.
これは私が最初に見つけたMetaGraphの説明です.分かりやすいです.
When you are saving your graph,a MetaGraph is created.This the graph itself,and all the other metadata necessary for computation,as well as some user info can saved and sivericon.
主にtf.SavedModelBuiderクラスを使用してこの作業を完了し、複数の計算図を一つのPBファイルに保存することができます.複数のMetaGraphがあれば、最初のMetaGraphのバージョン番号だけを保持します.また、各MetaGraphに特殊な名前を指定して区別しなければなりません.例えば、serving or trining、CPU or GPU.
典型的なPBファイル保存のコードを見てみましょう.
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    #          name  
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants     output_node_names,list(),    
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    #    OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    #        PB   
    with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    #   
    # INFO:tensorflow:Froze 1 variables.
    # Converted 1 variables to const ops.
    # 31
PBモデルファイルの一般コードを読み込みます.
from tensorflow.python.platform import gfile

sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='') #      

#                
sess.run(tf.global_variables_initializer())

#        
print(sess.run('b:0'))
# 1

#   
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

op = sess.graph.get_tensor_by_name('op_to_store:0')

ret = sess.run(op,  feed_dict={input_x: 5, input_y: 5})
print(ret)
#    26
また、save model形式として保存しても、モデルのPBファイルを生成することができ、より簡単です.
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    #          name  
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    #    OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))




    #     ,    saved_model_builder  
    builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
    #          ,       session,    tag, 
    #         ,     
    builder.add_meta_graph_and_variables(sess,
                                       ['cpu_server_1'])


#       MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
#  ...
#  builder.add_meta_graph([tag_constants.SERVING])
#...

builder.save()  #    PB   
保存してからsaved_に行きます.モデルディレクトリーの下にsaved_があります.model.pbファイル及びvariablesフォルダ.名前の通り、variabelesはすべての変数を保存します.モデル構造などの情報を保存するためにmodel.pbを使用します.
この方法はモデルを導入する方法に対応する.
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = './'

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ['cpu_server_1'], './way-4savemodel') 
    #            tag
    #         model      ,         variables
    sess.run(tf.global_variables_initializer())

    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')

    op = sess.graph.get_tensor_by_name('op_to_store:0')

    ret = sess.run(op,  feed_dict={input_x: 5, input_y: 5})
    print(ret)
#             session,    tag,         ,        
これは前の導入PBモデルと同じで、tenssorのnameを知ることです.では、どうやってテナントを知らずに使うことができますか?add_を与えるmeta_graphand_variables方法は第三のパラメータに入ります.def_mapでいいです転載:https://zhuanlan.zhihu.com/p/32887066