Tensorflow使用ノート(1):Tensorflowのモデル保存と使用

3495 ワード

トレーニングしたモデルパラメータを保存して使用する方法
引用する
最近Tensorflowを学んでCNNを構築して、訓練は時間を費やして、訓練した各パラメータを保存するのが最も簡便で、ネット上で多くの教程があって、しかし教程に従って順風満帆になってやはりいくつかの穴を踏んだとは限らなくて、それから自分で穴を埋めました
トレーニングの結果を保存する方法:
セッションをsessと仮定して、計算図はgraphネット上で多くの資料を見て、使用します
saver=tf.train.Saver()  #                
file_name = 'saved_model/model.ckpt'  #            saved_model    model.ckpt  
saver.saver(sess,file_name )  #         

これでモデルを保存して、これでいいのでしょうか?まだだめですが、このようにして、IDEはいつも間違っています:No Variable to save
そして私の考えは:sessをsaverに伝えるかもしれませんか?次は私のコードと試行的な修正です.
graph = tf.Graph()  #    
with graph.as_default():
    #      
    ...
    #               ,        
sess = tf.Session(graph=graph)  #                 
#    saver  sess    ,       
saver=tf.train.Saver(sess)  #                
saver.saver(sess,'saved_model/model.ckpt')

保存できる変数はありません.考えてみればgraphに載せるかも?
graph = tf.Graph()
with graph.as_default():
    #      
    ... 
    # ---
sess = tf.Session(graph=graph)  #                 
#    graph      
saver = tf.train.Saver(graph)
saver.saver(sess,'saved_model/model.ckpt')

結局だめだった
何度も試したあげく、やっと正しく直した.
graph = tf.Graph()
with graph.as_default():
    #      
    ... 
    # ---
    saver = tf.train.Saver()  #             ,         ,             

sess = tf.Session(graph=graph)  #                 
#   ,        
saver.saver(sess,'saved_model/model.ckpt')

保存成功、フォルダsaved_modelの下には、.data,.index,.metaの接尾辞を持ついくつかのファイルcheckpointファイル(このファイルは重要です.記録されています)が表示されます.私たちはこの3つの書類を無視できるようだ.
まずまとめておきます
サブマップを定義したり、tensorflowのデフォルト計算図を使用せずにgraphを自分で定義したりした場合、graphを定義して最後に定義し、どのサブマップの変数を保存したいのか、どのサブマップに関連するSaverを定義してこそ、目的の効果を実現することができます.Tensorflowのグラフィックとセッションセッションセッションはまだ抽象的で、うっかり混乱してしまいました.
Notes:with graph.asを使用していないと仮定します.graph():この構造は、スクリプト上でtensorflowの変数を直接定義し、saver()を使用しても問題ないはずです.
保存したモデルパラメータの読み込みと使用
保存したら、別の新しいスクリプトを呼び出すにはどうすればいいですか.例えば、test.pyファイルで私のコードを使用してテストする必要があります.保存したモデルパラメータをどのように使用しますか.
2つの方法があります.
先に訓練して、計算図を構築してすでにあなたが定義したネットワークパラメータのコードをtest.pyファイルの下に貼り付けます
graph = tf.Graph()
with graph.as_default():
    #      
    ... 
    # ---
    saver = tf.train.Saver()

特に注意して、前のステップまで運転すると
次に次のコードを使用します.
with tf.Session(graph=graph) as sess:
        check_point_path = 'saved_model/' #           
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)

#               saver.restore(sess,ckpt.model_checkpoint_path)  #     ,               ,     。

上記の方法が少し煩わしいと思ったら、train.pyというスクリプトのコードを直接import train.pyで仮定することができます.
graph = tf.Graph()
with graph.as_default():
    #      
    ... 
    # ---
    saver = tf.train.Saver()

ではtest.pyにこのように書くことができます
import train
#   python       
graph = train.graph
sess = train.sess

with tf.Session(graph=graph) as sess:
        check_point_path = 'saved_model/' #           
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)

#               saver.restore(sess,ckpt.model_checkpoint_path)  


Notes:saver.restore(sess,ckpt.model_checkpoint_path)を使用すると、このときsess.run(init)を使用してパラメータを初期化する必要はありません(そうしないと訓練されたパラメータを上書きします).前にrunを使用して初期化すると、重みは定義に基づいて初期化されますが、この文を使用するとモデルのパラメータが上書きされます
最後にもう一度
高級な使い方のようですが、選択によって反復回数を変えて更新する際の重みをつけることができます.ここでは簡単にまとめて、後で勉強してから更新しましょう