TensorflowノートGraph

24734 ワード

[TOC] https://blog.csdn.net/lovelyaiq/article/details/78646401悪くない
モデルインタフェースの設定
保存プロセス
import tensorflow as tf

# Two scalar variables; their sum is scaled by a value supplied at run time.
w1 = tf.Variable(20.0, name="w1")
w2 = tf.Variable(30.0, name="w2")

# b1 is a placeholder rather than a Variable, so its value is not baked into
# the graph and has to be supplied when the graph is run or imported.
# b1 = tf.Variable(2.0, name="bias")
b1 = tf.placeholder(tf.float32, name='bias')

w3 = tf.add(w1, w2)

# Give the output op an explicit name; the freezing call below selects it
# by that name ("out").
out = tf.multiply(w3, b1, name="out")

# Turn the Variables into constants and serialize the resulting graph.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Freeze the graph, keeping "out" as the requested output node.
    graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["out"])
    # Pass the target directory as `logdir` (instead of '.' plus a path inside
    # the file name): write_graph creates `logdir` when it is missing, so the
    # write no longer depends on ./checkpoint_dir existing beforehand.
    # The resulting file is the same: ./checkpoint_dir/graph.pb
    tf.train.write_graph(graph, './checkpoint_dir', 'graph.pb', as_text=False)

プロセスの使用
import tensorflow as tf
with tf.Session() as sess:
    # Read the serialized graph file from disk in binary mode.
    with open('./checkpoint_dir/graph.pb', 'rb') as f:
        graph_def = tf.GraphDef() # a fresh, empty GraphDef message -- not sess.graph_def
        graph_def.ParseFromString(f.read())
        # Import the parsed graph: map the constant 4. onto the 'bias:0'
        # tensor and ask for the 'out:0' tensor back.
        output = tf.import_graph_def(graph_def, input_map={'bias:0':4.}, return_elements=['out:0'])
        # `output` is the list of requested return elements (one entry here).
        print(sess.run(output))

モデルを読み込み、修正し、fine tune
既存のグラフ構造の上に新しいネットワークを構築できますか?もちろんできます。graph.get_tensor_by_name() メソッドで適切なテンソルにアクセスし、それを基にグラフを構築します。以下は実際の例です。ここでは、メタグラフを使って VGG の事前学習済みネットワークをロードし、最後の層の出力数を 2 に変更して、新しいデータでファインチューニングします。
......  
......  
# Rebuild the pre-trained network's graph from its meta-graph file.
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning

# Access the appropriate output for fine-tuning
fc7 = graph.get_tensor_by_name('fc7:0')

# Use this if you only want to change gradients of the last layer:
# stop_gradient passes the value through but blocks back-propagation.
fc7 = tf.stop_gradient(fc7)
fc7_shape = fc7.get_shape().as_list()

# Number of classes for the new output layer.  The original snippet defined
# `new_outputs` but then used the undefined name `num_outputs` (NameError);
# both names are bound here so either spelling works.
new_outputs = 2
num_outputs = new_outputs

# The new layer's input width is fc7's last dimension.  Indexing with [-1]
# equals the original [3] for a rank-4 tensor and does not raise IndexError
# for lower ranks.
# NOTE(review): tf.matmul below needs fc7 to be matrix-shaped -- confirm the
# rank of 'fc7:0' in the actual checkpoint.
weights = tf.Variable(tf.truncated_normal([fc7_shape[-1], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()  

ckptの保存と読み込み
  def snapshot(self, sess, iter):
    """Write a model checkpoint and a pickle of the training bookkeeping.

    Two files are written into ``self.output_dir`` (created if missing),
    both named ``<cfg.TRAIN.SNAPSHOT_PREFIX>_iter_<iter>``:

    * ``.ckpt`` -- the model, saved through ``self.saver``.
    * ``.pkl``  -- six objects pickled one after another: the numpy random
      state, the cursor and permutation of ``self.data_layer``, the cursor
      and permutation of ``self.data_layer_val``, and ``iter`` itself.

    Args:
      sess: session handed to ``self.saver.save``.
      iter: iteration number (formatted with ``{:d}``).

    Returns:
      Tuple ``(ckpt_path, pkl_path)``.
    """
    net = self.net

    out_dir = self.output_dir
    if not os.path.exists(out_dir):
      os.makedirs(out_dir)

    # Model snapshot.
    filename = os.path.join(
      out_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt')
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Side file with the meta information (random state, data positions, ...).
    nfilename = os.path.join(
      out_dir, cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl')

    # The order of this tuple is the on-disk order of the pickled objects.
    meta = (
      np.random.get_state(),       # numpy random state
      self.data_layer._cur,        # position in the training database
      self.data_layer._perm,       # shuffled indexes of the training database
      self.data_layer_val._cur,    # position in the validation database
      self.data_layer_val._perm,   # shuffled indexes of the validation database
      iter,                        # iteration this snapshot belongs to
    )
    with open(nfilename, 'wb') as fid:
      for item in meta:
        pickle.dump(item, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sess, sfile, nfile):
    """Restore the model from ``sfile`` and the training state from ``nfile``.

    ``nfile`` is read as six consecutively pickled objects: numpy random
    state, training cursor, training permutation, validation cursor,
    validation permutation, and the iteration of the snapshot.

    Args:
      sess: session handed to ``self.saver.restore``.
      sfile: path of the model checkpoint.
      nfile: path of the pickle file with the meta information.

    Returns:
      The iteration number stored in ``nfile``.
    """
    print('Restoring model snapshots from {:s}'.format(sfile))
    self.saver.restore(sess, sfile)
    print('Restored.')

    # The other training states are restored too so that a run can resume
    # as exactly as possible (TODO xinlei): the TensorFlow random state is
    # currently not available, so that part cannot be recovered.
    with open(nfile, 'rb') as fid:
      # Objects must be loaded in the same order they were dumped.
      loaded = [pickle.load(fid) for _ in range(6)]
    rng_state, train_cur, train_perm, val_cur, val_perm, last_snapshot_iter = loaded

    np.random.set_state(rng_state)
    train_layer = self.data_layer
    train_layer._cur = train_cur
    train_layer._perm = train_perm
    val_layer = self.data_layer_val
    val_layer._cur = val_cur
    val_layer._perm = val_perm

    return last_snapshot_iter

    # Note (standalone snippet): restore from the most recent checkpoint
    # recorded in `checkpoint_dir`, but only when one is found.
    # NOTE(review): `checkpoint_dir`, `saver` and `sess` are not defined in
    # this snippet -- they must come from the surrounding code.
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

input/output nodes
1.tensorboardで
python tensorflow/python/tools/import_pb_to_tensorboard.py \
--model_dir resnetv1_50.pb --log_dir /tmp/tensorboard

2.あるいはこう使う
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename = 'PATH_TO_PB.pb'
    # Parse the serialized model and import it into the default graph.
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)
LOGDIR = 'YOUR_LOG_LOCATION'
# Write the graph as an event file that TensorBoard can display.
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
# Close the writer so the event file is flushed to disk before the script
# exits; the original never flushed or closed it.
train_writer.close()

ckpt.data, ckpt.meta, pb,
.meta ファイルには現在のグラフ構造が保存されています。
.index ファイルには現在のパラメータ名が保存されています。
.data ファイルには現在のパラメータ値が保存されています。
meta ファイルはグラフ構造を保存し、weights などのパラメータは data ファイルに保存されます。つまり、グラフとパラメータデータは別々に保存されます。もっと率直に言えば、meta ファイルには weights などのデータは含まれません。ただし、meta ファイルは定数を保存することに注意してください。data ファイルのパラメータを meta ファイルの定数に変換するだけで済みます。
1. グラフ構造とパラメータ、すなわち ckpt.meta と ckpt.data をロードする
# NOTE(review): `ckpt` (a checkpoint state with `model_checkpoint_path`) and
# `imgs` (the input batch) are not defined in this snippet -- they must be
# provided by the surrounding code.
# Rebuild the graph structure from the .meta file ...
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    # ... then load the parameter values from the same checkpoint.
    # (The original passed the invalid token `.ckpt` here.)
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = tf.get_default_graph()
    # List every op with its output tensors, to find the tensor names.
    for op in graph.get_operations():
        print(op.name, op.values())
    # The original was missing the closing quote after 'fc3:0, which also
    # swallowed the feed_dict argument.
    print(sess.run(graph.get_tensor_by_name('fc3:0'), feed_dict={'Placeholder:0': imgs}))

2.データのみロード
# Only the parameter values are loaded here; the graph must already have been
# built in code before the Saver is created.
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('./model/')
    # get_checkpoint_state returns None when no checkpoint state is found;
    # fail with a clear message instead of an AttributeError on None.
    if not (ckpt and ckpt.model_checkpoint_path):
        raise IOError("no checkpoint found in './model/'")
    saver.restore(sess, ckpt.model_checkpoint_path)
  • PB(バイナリモデル)モデルロード方法
  • #      
    self.graph = tf.Graph()
    # Make the new graph the default graph for everything imported below.
    with self.graph.as_default():
        # Open the serialized model file in binary mode.
        with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
            # Create an empty GraphDef message to receive the serialized graph.
            graph_def = tf.GraphDef()
            # Parse the file's bytes into the GraphDef.
            graph_def.ParseFromString(f.read())
            # Import the GraphDef into the current default graph; name='' is
            # passed so that no name prefix is requested for the imported nodes.
            tf.import_graph_def(graph_def,name='')
            # Tensors are looked up with graph.get_tensor_by_name and can then
            # be handed to session.run.
            # Note the naming: 'conv1' is an operation name, 'conv1:0' is the
            # name of that operation's first output tensor.
            self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
            self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name   in self.layer_operation_names]
            # Pattern noted in the original as coming from the TensorFlow
            # object_detection API.
            ops=tf.get_default_graph().get_operations()
            # op.outputs holds an op's output tensors; output.name is each tensor's name.
            all_tensor_names={output.name for op in ops for output in op.outputs}
            tensor_dict={}
            for key in [ 'num_detections', 'detection_boxes', 'detection_scores','detection_classes', 'detection_masks'
          ]:
                  tensor_name=key+':0'# tensor name = key + ':0'
                  # Only keep the keys whose tensor actually exists in the graph.
                  if tensor_name in all_tensor_names:
                        tensor_dict[key]=tf.get_default_graph().get_tensor_by_name(tensor_name)
    
    

    PBファイルの実行
    import tensorflow as tf
    import  numpy as np
    import PIL.Image as Image
    from skimage import io, transform
    
    def recognize(jpg_path, pb_file_path):
        """Classify one image with a frozen .pb model and print the result.

        The graph is expected to expose the tensors "input:0", "softmax:0"
        and "output:0"; these names are looked up literally below.

        Args:
            jpg_path: path of the image to classify.
            pb_file_path: path of the serialized GraphDef (.pb) file.

        Returns:
            None. The tensors, the softmax output and the arg-max label are
            printed.
        """
        with tf.Graph().as_default():
            output_graph_def = tf.GraphDef()

            with open(pb_file_path, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                # Import the graph definition without a name prefix.
                _ = tf.import_graph_def(output_graph_def, name="")

            with tf.Session() as sess:
                init = tf.global_variables_initializer()
                sess.run(init)

                # Fetch the input and output tensors by name.
                # print(...) with one argument behaves the same on Python 2
                # and 3; the original `print x` statements are a SyntaxError
                # on Python 3.
                input_x = sess.graph.get_tensor_by_name("input:0")
                print(input_x)
                out_softmax = sess.graph.get_tensor_by_name("softmax:0")
                print(out_softmax)
                out_label = sess.graph.get_tensor_by_name("output:0")
                print(out_label)

                img = io.imread(jpg_path)
                img = transform.resize(img, (224, 224, 3))
                # Run the image as a batch of one (the -1 axis).
                img_out_softmax = sess.run(out_softmax, feed_dict={input_x: np.reshape(img, [-1, 224, 224, 3])})

                # str.format keeps the printed text identical to the old
                # `print "label:", value` form on both Python versions.
                print("img_out_softmax: {}".format(img_out_softmax))
                prediction_labels = np.argmax(img_out_softmax, axis=1)
                print("label: {}".format(prediction_labels))

    recognize("vgg16/picture/dog/dog3.jpg", "vgg16/vggs.pb")

    あるいは簡単です.
    sess = tf.Session()
    # NOTE(review): assumes `tf` and `gfile`
    # (tensorflow.python.platform.gfile) are already imported.
    # Read the serialized model file and parse it into a GraphDef.  The file
    # is opened with `with` so the handle is closed after reading; the
    # original left it open and was indented inconsistently.
    with gfile.FastGFile("model.pb", 'rb') as model_f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_f.read())
    # Import the graph into the default graph and get back the tensor named
    # "add:0" via return_elements.
    c = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(c))
    # Example output from the original note: [array([ 11.], dtype=float32)]

    4. PB 形式モデルとして保存する。処理の流れ: tf.get_default_graph().as_graph_def() で現在のグラフの計算ノード情報を取得し、graph_util.convert_variables_to_constants で関連ノードの値を定数として固定し、tf.gfile.GFile でモデルを永続化する。
    # coding=UTF-8
    import tensorflow as tf
    import shutil
    import os.path
    from tensorflow.python.framework import graph_util
    
    output_graph = "model/pb/add_model.pb"

    # A deliberately tiny model; a real CNN/RNN graph would be frozen the
    # same way.
    input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
    W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
    B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
    _y = (input_holder * W1) + B1
    # predictions = tf.greater(_y, 50, name="predictions")  # alternative: true when _y > 50, else false
    predictions = tf.add(_y, 10, name="predictions")  # the output node

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        # str.format instead of the Python 2 print statement (a SyntaxError on
        # Python 3); the printed text is unchanged.
        print("predictions :  {}".format(sess.run(predictions, feed_dict={input_holder: [10.0]})))
        # Serialize the current default graph's nodes into a GraphDef.
        graph_def = tf.get_default_graph().as_graph_def()
        # Replace the variables needed by the output node with constants.
        # Output *node* names are passed here ("predictions"), not tensor
        # names ("predictions:0").
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            ["predictions"]  # output node(s) to keep in the pb
        )
        # Make sure the target directory exists before writing into it.
        tf.gfile.MakeDirs(os.path.dirname(output_graph))
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

        # Alternative way to write the same frozen graph.  The original passed
        # the undefined name `graph` (NameError); the frozen GraphDef is what
        # is available here.  The directory is given as `logdir` so that
        # write_graph creates it when it is missing.
        tf.train.write_graph(output_graph_def, './checkpoint_dir', 'graph.pb', as_text=False)

        print("%d ops in the final graph." % len(output_graph_def.node))
        print(predictions)
    
    # for op in tf.get_default_graph().get_operations():         
    #     print (op.name)

    4 ckpt を PB 形式に変換する。処理の流れ: CKPT モデルのパスからモデルのグラフと変数データを取得し、import_meta_graph でモデルのグラフをインポートし、saver.restore でグラフ中の各変数の値を復元し、graph_util.convert_variables_to_constants でモデルを永続化する。
    # coding=UTF-8
    import tensorflow as tf
    import os.path
    import argparse
    from tensorflow.python.framework import graph_util
    
    # Directory and file name of the frozen model that will be written.
    MODEL_DIR = "model/pb"
    MODEL_NAME = "frozen_model.pb"
    
    if not tf.gfile.Exists(MODEL_DIR): # create the output directory if it is missing
        tf.gfile.MakeDirs(MODEL_DIR)
    
    def freeze_graph(model_folder):
        """Freeze the checkpoint found in `model_folder` into one .pb file.

        The graph is rebuilt from the checkpoint's ``.meta`` file, variable
        values are restored, the variables needed by the "predictions" node
        are converted to constants, and the result is written to
        ``MODEL_DIR/MODEL_NAME``.

        Args:
            model_folder: directory containing the checkpoint state file.

        Returns:
            None. Progress information is printed.
        """
        # Look up the checkpoint state recorded in the folder.
        checkpoint = tf.train.get_checkpoint_state(model_folder)
        input_checkpoint = checkpoint.model_checkpoint_path  # checkpoint path
        output_graph = os.path.join(MODEL_DIR, MODEL_NAME)  # path of the .pb to write

        # Comma-separated names of the output nodes to keep.
        output_node_names = "predictions"
        # clear_devices: whether or not to clear the device field for an
        # `Operation` or `Tensor` during import.
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

        graph = tf.get_default_graph()  # the graph just imported
        input_graph_def = graph.as_graph_def()  # its serialized GraphDef

        with tf.Session() as sess:
            saver.restore(sess, input_checkpoint)  # load the variable values

            # Sanity check: tensors can be fetched and fed by name.
            # str.format replaces the Python 2 print statement, which is a
            # SyntaxError on Python 3; the printed text is unchanged.
            print("predictions :  {}".format(
                sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]})))

            # Convert the variables needed by the output nodes to constants.
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(",")  # multiple nodes are comma-separated
            )
            with tf.gfile.GFile(output_graph, "wb") as f:
                f.write(output_graph_def.SerializeToString())
            print("%d ops in the final graph." % len(output_graph_def.node))

            # List every op with its output tensors.
            for op in graph.get_operations():
                print(op.name, op.values())
    
    if __name__ == '__main__':
        # One required positional argument: the folder holding the ckpt
        # model.  Without it argparse exits with a usage error.
        cli = argparse.ArgumentParser()
        cli.add_argument("model_folder", type=str, help="input ckpt model dir")
        cli_args = cli.parse_args()
        freeze_graph(cli_args.model_folder)
        # freeze_graph("model/ckpt")  # alternative: hard-coded folder