5.1 Raising MNIST classifier accuracy above 98%


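How can the accuracy of the MNIST classifier be raised above 98%?
Techniques:
  • The number of layers in the network and the number of neurons in each layer
  • Choice of optimizer: Adam, SGD, Adagrad, RMSprop, Adadelta
  • Learning-rate schedule: decay the learning rate exponentially as the number of iterations grows
  • Setting the number of training epochs
  • Program: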
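    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # load the dataset, with labels encoded as one-hot vectors
    mnist = input_data.read_data_sets("C:/Users/WangT/Desktop/MNIST_data",one_hot=True)
    
    batch_size = 100  # number of images per mini-batch
    n_batch = mnist.train.num_examples // batch_size
    # number of batches per epoch; // is floor (integer) division, whereas / would give a float
    # mnist.train.num_examples is the size of the training set; likewise mnist.validation.num_examples and mnist.test.num_examples
    
    x = tf.placeholder(tf.float32,[None,784])
    y = tf.placeholder(tf.float32,[None,10])
    keep_prob = tf.placeholder(tf.float32)  # dropout keep probability (unused while dropout is commented out)
    # placeholders are fed with data at run time; each MNIST image is flattened into a 784-dimensional vector,
    # so the input is a 2-D tensor of shape [None, 784], where None means the batch dimension can have any length
    
    lr = tf.Variable(0.001, dtype = tf.float32)  # learning rate, updated once per epoch
    
    # weights and biases are tf.Variable objects because the optimizer updates them during training
    W1 = tf.Variable(tf.truncated_normal([784,500],stddev=0.1))
    b1 = tf.Variable(tf.zeros([500])+0.1)
    L1 = tf.nn.tanh(tf.matmul(x,W1)+b1)
    #L1_drop = tf.nn.dropout(L1,keep_prob)
    
    W2 = tf.Variable(tf.truncated_normal([500,300],stddev=0.1))
    b2 = tf.Variable(tf.zeros([300])+0.1)
    L2 = tf.nn.tanh(tf.matmul(L1,W2)+b2)
    
    W3 = tf.Variable(tf.truncated_normal([300,10],stddev=0.1))
    b3 = tf.Variable(tf.zeros([10])+0.1)
    logits = tf.matmul(L2,W3)+b3        # raw scores of the output layer
    prediction = tf.nn.softmax(logits)  # class probabilities
    
    # loss function
    # quadratic cost (tf.square squares the error, tf.reduce_mean averages it):
    # loss = tf.reduce_mean(tf.square(y - prediction))
    # cross-entropy cost; softmax_cross_entropy_with_logits applies softmax internally,
    # so it must receive the raw logits rather than the softmax output
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=logits))
    #train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    train_step = tf.train.AdamOptimizer(lr).minimize(loss)
    # train with the Adam optimizer at learning rate lr, minimizing loss
    
    init = tf.global_variables_initializer()  # initializes all variables
    
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
    # tf.equal() compares element-wise and returns True or False; tf.argmax(y,1) returns the index of the largest value in each row of y
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    # tf.cast() converts the booleans to float32 (True -> 1.0, False -> 0.0), so their mean is the accuracy
    
    with tf.Session() as sess:  # create a session
        sess.run(init)
        for epoch in range(51):  # train for 51 epochs
            # decay the learning rate once per epoch: lr = 0.001 * 0.95**epoch
            # (run outside the batch loop so a new assign op is not added to the graph at every step)
            sess.run(tf.assign(lr,0.001*(0.95**epoch)))
            for batch in range(n_batch):  # n_batch mini-batches per epoch
                batch_xs,batch_ys = mnist.train.next_batch(batch_size)  # fetch the next mini-batch
                sess.run(train_step, feed_dict={x:batch_xs,y:batch_ys})  # run one training step
            
            acc1 = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})  # test-set accuracy
            acc2 = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})  # training-set accuracy
            print("Iter"+str(epoch)+",Testing Accuracy"+str(acc1)+",Training accuracy"+str(acc2))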
    
    
    Extracting C:/Users/WangT/Desktop/MNIST_data\train-images-idx3-ubyte.gz
    Extracting C:/Users/WangT/Desktop/MNIST_data\train-labels-idx1-ubyte.gz
    Extracting C:/Users/WangT/Desktop/MNIST_data\t10k-images-idx3-ubyte.gz
    Extracting C:/Users/WangT/Desktop/MNIST_data\t10k-labels-idx1-ubyte.gz
    Iter0,Testing Accuracy0.9499,Training accuracy0.955691
    Iter1,Testing Accuracy0.9627,Training accuracy0.970036
    Iter2,Testing Accuracy0.9661,Training accuracy0.977127
    Iter3,Testing Accuracy0.9705,Training accuracy0.981327
    Iter4,Testing Accuracy0.9759,Training accuracy0.986345
    Iter5,Testing Accuracy0.9755,Training accuracy0.988509
    Iter6,Testing Accuracy0.9749,Training accuracy0.989545
    Iter7,Testing Accuracy0.9783,Training accuracy0.991491
    Iter8,Testing Accuracy0.9771,Training accuracy0.992618
    Iter9,Testing Accuracy0.9793,Training accuracy0.992564
    Iter10,Testing Accuracy0.9789,Training accuracy0.993691
    Iter11,Testing Accuracy0.9796,Training accuracy0.994327
    Iter12,Testing Accuracy0.9779,Training accuracy0.994382
    Iter13,Testing Accuracy0.9788,Training accuracy0.995036
    Iter14,Testing Accuracy0.9795,Training accuracy0.995073
    Iter15,Testing Accuracy0.9801,Training accuracy0.995491
    Iter16,Testing Accuracy0.9806,Training accuracy0.995727
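As a rough illustration of what the loss and accuracy nodes in the program compute, the following plain-Python sketch (no TensorFlow) mirrors `tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(...))` and `tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y,1), tf.argmax(prediction,1)), tf.float32))` on a tiny batch. The helper names `softmax`, `cross_entropy_from_logits`, `argmax` and `accuracy` are hypothetical, and the toy batch is made up. It also shows what happens when softmax outputs are passed where logits are expected: softmax is then applied twice.

```python
import math

def softmax(logits):
    """Numerically stable softmax of one row of scores."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_from_logits(labels, logits):
    """Batch mean of -sum(y * log(softmax(z))), i.e. softmax cross-entropy on raw logits."""
    losses = []
    for y_row, z_row in zip(labels, logits):
        p = softmax(z_row)
        losses.append(-sum(y * math.log(q) for y, q in zip(y_row, p)))
    return sum(losses) / len(losses)

def argmax(row):
    """Index of the largest value in a row (like tf.argmax(..., 1) for one row)."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(labels, scores):
    """Fraction of rows whose predicted class matches the one-hot label."""
    correct = [argmax(y) == argmax(s) for y, s in zip(labels, scores)]
    return sum(float(c) for c in correct) / len(correct)

# toy batch: 3 examples, 3 classes (made-up numbers)
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
logits = [[4.0, 1.0, 0.0], [0.5, 3.0, 0.2], [2.0, 0.1, 1.0]]

loss_ok = cross_entropy_from_logits(labels, logits)
# feeding softmax outputs in place of logits applies softmax a second time
loss_double = cross_entropy_from_logits(labels, [softmax(z) for z in logits])

print("loss on logits:       ", loss_ok)
print("loss on softmax output:", loss_double)
print("accuracy:", accuracy(labels, logits))  # the third example is misclassified
```

Because probabilities lie in [0, 1], the second softmax pushes them toward a uniform distribution: with n classes no class can receive more than e/(e+n-1), so even a perfect prediction keeps a per-example loss of at least -log(e/(e+n-1)), about 1.46 for the 10 MNIST classes. This is why `softmax_cross_entropy_with_logits` should be given the raw logits.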
    

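    New in this section:
    1. lr = tf.Variable(0.001, dtype = tf.float32): the learning rate is defined as a variable with initial value 0.001 and type float32, so that it can be changed during training.
    2. sess.run(tf.assign(lr,0.001*(0.95**epoch))): starting from 0.001, the learning rate is multiplied by 0.95 every epoch, and tf.assign() writes the new value into the variable.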
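The schedule set with `tf.assign(lr, 0.001*(0.95**epoch))` is plain exponential decay, lr = 0.001 × 0.95^epoch. A minimal plain-Python sketch (the function name `decayed_lr` and its defaults are illustrative, chosen to match the values in the program) prints the rate at a few epochs:

```python
def decayed_lr(epoch, base_lr=0.001, decay=0.95):
    """Learning rate for a given epoch under exponential decay: base_lr * decay**epoch."""
    return base_lr * decay ** epoch

# the program runs 51 epochs, numbered 0..50
for epoch in (0, 1, 10, 25, 50):
    print(f"epoch {epoch:2d}: lr = {decayed_lr(epoch):.6g}")
```

By the last epoch (epoch 50) the rate has fallen to about 7.7e-5, roughly 7.7% of its starting value. As an alternative, TensorFlow 1.x provides `tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase=...)`, which computes a decay of this form from a global step counter; with `staircase=True` and `decay_steps` set to `n_batch`, it would drop once per epoch as here.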