Tensorflowは訓練されたモデルでテストを行うことができます。


Tensorflowは訓練されたモデルを使って新しいデータをテストすることができます。二つの方法があります。最初の方法はモデルを呼び出して同じpyファイルにトレーニングすることです。中の状況は簡単です。第二は、トレーニングプロセスと呼び出しモデルプロセスがそれぞれ2つのpyファイルにあります。本稿では第二の方法を説明する。
モデルの保存
tenssorflowはトレーニングモデルを保存できるインターフェースを提供しています。使うのも難しくないです。直接コードを説明します。

#    
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#         
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess: 
    sess.run(init) 
    saver.save(sess,"save/model.ckpt") 
    train_step.run({x: train_x, y_: train_y})
以上のコードは模型の保存を完成しました。注意すべきは次の行のコードです。

tf.add_to_collection('network-output', y)
この行のコードは神経ネットワークの出力を保存しています。これは導入モデルを使用した後の過程で重要な役割を果たします。
モデルの導入
モデルトレーニングをして保存してから、モデルのテストセットの表現を評価するために導入されます。インターネット上の多くの文章は簡単な四則演算だけで例を作ります。それとも先にコードを入れますか?

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('./model.ckpt.meta')
  saver.restore(sess, './model.ckpt')# .data  
  pred = tf.get_collection('network-output')[0]

  graph = tf.get_default_graph()
  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

  y = sess.run(pred, feed_dict={x: test_x, y_: test_y})
キーコードを説明します。まずpred=tf.get_です。collection('pred_.network')[0],この行のコードは訓練中にネットワーク出力の「インターフェース」を獲得します。簡単に理解すると、tf.get_を通じて(通って)です。collection()この方法はネットワーク構造全体を取得する。ネットワーク構造を取得すると、対応するデータy=sess.run(pred,feed_dict={x:test_]x,y_:test_y}訓練中の入力は

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
したがって、モデルを導入した後に必要な入力もこれに対応して以下のコードで取得できます。

  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]
モデルを使う最後のステップはテストセットを入力し、訓練されたネットワークによって評価します。

  sess.run(pred, feed_dict={x: test_x, y_: test_y})
この行のコードを理解してください。sess.run()の関数の原型は

run(fetches, feed_dict=None, options=None, run_metadata=None)
Tensorflow対feed_dictはfetch操作を行うので、モデル導入後の演算とは、訓練したネットワークに従って入力したデータを計算します。
以上のTensorflowが訓練されたモデルでテストを行うことができました。小編集は皆さんに全部の内容を共有しています。参考にしてもらいたいです。どうぞよろしくお願いします。