5.1 mnistデータ分類器の精度を98%以上に向上
22642 ワード
Mnistデータ分類器の精度を98%以上に向上させるにはどうすればいいですか?
テクニック:ネットワークの層数および各層のニューロン数の調整。オプティマイザの選択:Adam、SGD、Adagrad、RMSprop、Adadelta。学習率の更新:反復回数の増加に伴い指数的に減少させる。学習エポック数の設定。プログラム:
新しい内容:
テクニック:
# MNIST classifier: two tanh hidden layers (784 -> 500 -> 300 -> 10) trained
# with Adam and an exponentially decaying learning rate for 51 epochs.
# Test/train accuracy is reported after every epoch.
mnist = input_data.read_data_sets("C:/Users/WangT/Desktop/MNIST_data",one_hot=True)

batch_size = 100
# Number of complete batches per epoch (integer division drops the remainder).
n_batch = mnist.train.num_examples // batch_size

# Inputs: flattened 28x28 images and one-hot labels. The leading dimension is
# None so any batch size can be fed.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)  # reserved for dropout (dropout layer currently disabled)

# Learning rate as a Variable so it can be decayed between epochs.
lr = tf.Variable(0.001, dtype=tf.float32)

# Hidden layer 1: 784 -> 500, tanh activation. Weights drawn from a truncated
# normal, biases initialized slightly positive.
W1 = tf.Variable(tf.truncated_normal([784, 500], stddev=0.1))
b1 = tf.Variable(tf.zeros([500]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)

# Hidden layer 2: 500 -> 300, tanh activation.
W2 = tf.Variable(tf.truncated_normal([500, 300], stddev=0.1))
b2 = tf.Variable(tf.zeros([300]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1, W2) + b2)

# Output layer: 300 -> 10. Keep the raw logits separate from the softmax
# output: softmax_cross_entropy_with_logits applies softmax internally.
W3 = tf.Variable(tf.truncated_normal([300, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L2, W3) + b3
prediction = tf.nn.softmax(logits)

# BUG FIX: the original passed the already-softmaxed `prediction` as `logits`,
# so softmax was applied twice and the loss gradient was wrong. Pass the raw
# pre-softmax activations instead.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
train_step = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()

# Accuracy: fraction of samples whose predicted class (argmax over the 10
# outputs) matches the true label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# BUG FIX: build the learning-rate update op ONCE. The original called
# tf.assign(...) inside the batch loop, adding a new node to the graph on
# every single iteration (an unbounded graph leak) and re-running it n_batch
# times per epoch even though the value only changes once per epoch.
new_lr = tf.placeholder(tf.float32)
update_lr = tf.assign(lr, new_lr)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(51):
        # Exponential decay: lr = 0.001 * 0.95**epoch, updated once per epoch.
        sess.run(update_lr, feed_dict={new_lr: 0.001 * (0.95 ** epoch)})
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        # Evaluate on the full test and train sets after each epoch.
        acc1 = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        acc2 = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels})
        print("Iter"+str(epoch)+",Testing Accuracy"+str(acc1)+",Training accuracy"+str(acc2))
Extracting C:/Users/WangT/Desktop/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/WangT/Desktop/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/WangT/Desktop/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/WangT/Desktop/MNIST_data\t10k-labels-idx1-ubyte.gz
Iter0,Testing Accuracy0.9499,Training accuracy0.955691
Iter1,Testing Accuracy0.9627,Training accuracy0.970036
Iter2,Testing Accuracy0.9661,Training accuracy0.977127
Iter3,Testing Accuracy0.9705,Training accuracy0.981327
Iter4,Testing Accuracy0.9759,Training accuracy0.986345
Iter5,Testing Accuracy0.9755,Training accuracy0.988509
Iter6,Testing Accuracy0.9749,Training accuracy0.989545
Iter7,Testing Accuracy0.9783,Training accuracy0.991491
Iter8,Testing Accuracy0.9771,Training accuracy0.992618
Iter9,Testing Accuracy0.9793,Training accuracy0.992564
Iter10,Testing Accuracy0.9789,Training accuracy0.993691
Iter11,Testing Accuracy0.9796,Training accuracy0.994327
Iter12,Testing Accuracy0.9779,Training accuracy0.994382
Iter13,Testing Accuracy0.9788,Training accuracy0.995036
Iter14,Testing Accuracy0.9795,Training accuracy0.995073
Iter15,Testing Accuracy0.9801,Training accuracy0.995491
Iter16,Testing Accuracy0.9806,Training accuracy0.995727
新しい内容:
1. lr = tf.Variable(0.001, dtype=tf.float32) # 学習率を定数ではなく変数として定義する。初期値は0.001、型はfloat32。
2. sess.run(tf.assign(lr, 0.001*(0.95**epoch))) # エポックごとに tf.assign() で学習率を更新する。
初期値0.001から、エポック数の増加に伴って学習率を指数的に減衰させることで、学習後半の収束を安定させる。