tensorflow呼び出しプリトレーニングモデル


最近tensorflowのプリトレーニングモデルを使って、自分の心得を記録しています~
Tensorflow保存したモデルのウェイト値を読み込み、出力します.
リファレンスリンクhttps://blog.csdn.net/AManFromEarth/article/details/81057577 https://blog.csdn.net/aiseu001/article/details/79851176
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python import pywrap_tensorflow
#  ,  tensorflow   python       
model_reader = pywrap_tensorflow.NewCheckpointReader(r"saver-test")

#  , reader      dict     
var_dict = model_reader.get_variable_to_shape_map()
print(len(var_dict))#          
print(var_dict) #          

#               
w1 = model_reader.get_tensor("conv1/W") #       conv1/W   (conv1     )
print(type(w1)) #    w1     
print(w1.shape) #    w1     (11, 11, 3, 96)
# print(w1)  #    w1  

#                
for key in var_dict:
    print("variable name: ", key)
    print(model_reader.get_tensor(key))


#         ,  :
with open("output.txt","w+") as f:
#      
    for key in var_dict:
        f.write(str(key))
        f.write(str(model_reader.get_tensor(key)))

プリトレーニングモデルを読み込み、選択的なロードパラメータがあります.
私は元のネットワークをいくつか修正して、いくつかの新しいパラメータを追加したので、予備訓練モデルを導入して元のパラメータの一部を初期化したいと思っています.新しく追加したパラメータはランダム初期化を採用しています.コード:
reader = tf.train.NewCheckpointReader("output/saver-test")
restore_dict = dict()
for v in tf.trainable_variables(): #                
    tensor_name = v.name.split(':')[0] #       :0   (conv1/b:0->conv1/b)
    # print(tensor_name) #           
    if reader.has_tensor(tensor_name): #                   ,       restore      
        print('has tensor', tensor_name)
        restore_dict[tensor_name] = v
saver = tf.train.Saver(restore_dict)#         

with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer()) #            
    sess.run(tf.local_variables_initializer())
    saver.restore(sess, "output/saver-test") # restore     
    #           conv1/b:0  ,    ,b             conv1/b   ,  restore  
    b=tf.get_default_graph().get_tensor_by_name("conv1/b:0") 
    print(sess.run(b))

前回の訓練結果を抽出して訓練を継続する
2つ目のケースとの違いは、モデル構造が同じで、これを復元するのがもっと簡単で、直接このように操作すればいいので、パラメータの初期化も~~~
saver = tf.train.Saver(max_to_keep=100) #max_to_keep     ckpt     ,   5
with tf.Session(config=config) as sess:
    saver.restore(sess, "output/saver")

リンクを共有:https://blog.csdn.net/qq_25737169/article/details/78125061 https://stackoverflow.com/questions/52532150/how-to-restore-pretrained-model-to-initialize-parameters