TensorFlow基礎チュートリアル:モデルの永続化(モデルの保存と読み取り)


TensorFlowは訓練したモデルを保存することができる。訓練が途中で中断されても、前回の訓練過程から継続できるほか、転移学習として他人の訓練済みモデルをもとに自分のモデルを訓練することもでき、とても便利である。
TensorFlowモデルcheckpointを保存した後、次のファイルを生成します.
|—checkpoint |—model_name.data-00000-of-00001 |—model_name.index |—model_name.meta
model_name は定義したモデル名。model_name.meta は計算グラフを保存したファイル、model_name.data は変数の値(データ)を保存したファイルである。
  • 保存モデル
  • saver = tf.train.Saver()            # saver オブジェクトを作成する
    saver.save(sess, checkpoint_path)   # sess 内のモデルを checkpoint_path に保存する
  • リカバリモデル
  • データのみロード
    saver.restore(sess, checkpoint_path)  # 保存済みのデータを sess に読み込む

    グラフとデータのロード
    meta_path = 'model_name.meta'   # グラフファイル
    model_path = 'model_name'       # モデルのパス
    saver = tf.train.import_meta_graph(meta_path)  # グラフを読み込む
    with tf.Session() as sess:
        saver.restore(sess, model_path)     # データを sess に読み込む
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name('InputData:0')  # 名前で tensor を取得する

    前回の記事(TensorFlow基礎チュートリアル:簡単なDNNを構築して手書き数字認識を実現)のすべてのプログラムコードを書き換える。
    トレーニングコード
    # coding: utf-8
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    
    import tensorflow as tf
    
    # Training hyperparameters.
    learning_rate = 0.001
    train_epochs = 10
    batch_size = 64
    checkpoint_path = 'checkpoint/'
    
    # Network dimensions: 784-pixel MNIST input, two 100-unit hidden layers,
    # 10 output classes.
    n_input = 784
    n_hidden1 = 100
    n_hidden2 = 100
    n_classes = 10
    
    # Name the placeholders explicitly so they can be fetched by name
    # (e.g. 'InputData:0') after the graph is restored from a checkpoint.
    x = tf.placeholder(tf.float32, shape=[None, n_input], name='InputData')
    y = tf.placeholder(tf.float32, shape=[None, n_classes], name='LabelData')
    
    # Randomly-initialized weights and biases; named so the Saver can
    # save and restore them deterministically.
    weights = {'w1': tf.Variable(tf.random_normal([n_input, n_hidden1]), name='W1'),
                      'w2': tf.Variable(tf.random_normal([n_hidden1, n_hidden2]), name='W2'),
                      'w3': tf.Variable(tf.random_normal([n_hidden2, n_classes]), name='W3')}
    biases = {'b1': tf.Variable(tf.random_normal([n_hidden1]), name='b1'),
                    'b2': tf.Variable(tf.random_normal([n_hidden2]), name='b2'),
                    'b3': tf.Variable(tf.random_normal([n_classes]), name='b3')}
    
    def inference(input_x):
        """Forward pass of a 3-layer MLP: two ReLU hidden layers, linear output.

        Args:
            input_x: float32 tensor of shape [None, n_input].

        Returns:
            Unnormalized logits of shape [None, n_classes].
        """
        # BUG FIX: the original used the global placeholder `x` in the first
        # matmul instead of the `input_x` parameter, silently ignoring the
        # argument. Behavior is unchanged for the existing call `inference(x)`.
        layer_1 = tf.nn.relu(tf.matmul(input_x, weights['w1']) + biases['b1'])
        layer_2 = tf.nn.relu(tf.matmul(layer_1, weights['w2']) + biases['b2'])
        out_layer = tf.matmul(layer_2, weights['w3']) + biases['b3']
        return out_layer
    
    # Build the graph: forward pass, loss, optimizer, and accuracy metric.
    with tf.name_scope('Inference'):
        logits = inference(x)
    with tf.name_scope('Loss'):
        # Softmax cross-entropy between logits and one-hot labels.
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
    with tf.name_scope('Optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss)
    with tf.name_scope('Accuracy'):
        pre_correct = tf.equal(tf.argmax(y, 1), tf.argmax(tf.nn.softmax(logits), 1))
        # Named 'acc' so it can be fetched as 'Accuracy/acc:0' after restore.
        accuracy = tf.reduce_mean(tf.cast(pre_correct, tf.float32), name='acc')
        # FIX: removed a leftover `print(accuracy)` debug statement that only
        # printed the Tensor object's repr at graph-construction time.
    
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(init)
        total_batch = int(mnist.train.num_examples / batch_size)
    
        # Resume from the latest checkpoint if one exists in checkpoint_path.
        checkpoint = tf.train.get_checkpoint_state(checkpoint_path)
        if checkpoint and checkpoint.model_checkpoint_path:
            # FIX: restore the path recorded in the checkpoint file rather than
            # re-hard-coding 'model.ckpt', so whatever checkpoint actually
            # exists is the one restored.
            saver.restore(sess, checkpoint.model_checkpoint_path)
            print('continue last train!!')
        else:
            print('restart train!!')
        for epoch in range(train_epochs):
            for batch in range(total_batch):
                batch_x, batch_y = mnist.train.next_batch(batch_size)
                sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
            # Every 5 epochs: evaluate on the last mini-batch and checkpoint.
            if (epoch + 1) % 5 == 0:
                loss_, acc = sess.run([loss, accuracy], feed_dict={x: batch_x, y: batch_y})
                # FIX: report the 1-based epoch number, consistent with the
                # (epoch + 1) % 5 cadence above.
                print("epoch {},  loss {:.4f}, acc {:.3f}".format(epoch + 1, loss_, acc))
                saver.save(sess, checkpoint_path + 'model.ckpt')    # writes model.ckpt.*
    
        print("optimizer finished!")
        # FIX: restored the garbled message text of the final status line.
        print("model saved in", checkpoint_path)

    モデルリカバリとテストセット精度の計算
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    import tensorflow as tf
    
    # Recreate the graph structure from the exported .meta file; this returns
    # a Saver that knows how to load the matching variable values.
    saver = tf.train.import_meta_graph('checkpoint/model.ckpt.meta')
    
    with tf.Session() as sess:
        # Load the trained weights into the rebuilt graph.
        saver.restore(sess, 'checkpoint/model.ckpt')
        default_graph = tf.get_default_graph()
        # Look the tensors up by the names they were given at training time.
        input_tensor = default_graph.get_tensor_by_name('InputData:0')
        label_tensor = default_graph.get_tensor_by_name('LabelData:0')
        acc_tensor = default_graph.get_tensor_by_name('Accuracy/acc:0')
    
        # Accuracy of the restored model over the full test set.
        feed = {input_tensor: mnist.test.images, label_tensor: mnist.test.labels}
        test_acc = sess.run(acc_tensor, feed_dict=feed)
        print('test accuracy', test_acc)

    githubソースダウンロードhttps://github.com/gamersover/tensorflow_basic_tutorial/tree/master/model_save_tutorial