Python・TensorFlow で3層全結合ニューラルネットワークによる手書き数字認識を実装
3707 文字
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
import cv2
import os
# Graph inputs: rows of flattened 28x28 images (784 values each) and their
# one-hot labels over 10 classes. The explicit op names are 'x' and 'y'.
x, y = (
    tf.placeholder(dtype=tf.float32, shape=[None, 784], name='x'),
    tf.placeholder(dtype=tf.float32, shape=[None, 10], name='y'),
)
# MNIST data sets read from ./MNIST with one-hot labels
# (accessed below as mnist.train and mnist.test).
mnist = input_data.read_data_sets(train_dir='./MNIST', one_hot=True)
# Number of examples per gradient-descent step in train_nn.
batch_size = 1000
def add_layer(input_data, input_num, output_num, activation_function=None):
    """Create one fully-connected layer and return its output tensor.

    Computes ``activation_function(input_data @ W + b)``, where ``W`` has shape
    ``[input_num, output_num]`` and ``b`` has shape ``[1, output_num]``. Both are
    new ``tf.Variable`` objects initialised from ``tf.random_normal``.

    Args:
        input_data: 2-D input tensor whose second dimension is ``input_num``.
            (This parameter shadows the module-level ``input_data`` import
            inside the function.)
        input_num: Number of input features.
        output_num: Number of units in the layer.
        activation_function: Optional callable applied to the affine output.
            When falsy (default ``None``), the raw affine output is returned.

    Returns:
        The layer's output tensor.
    """
    # Weights are created before biases on purpose: these Variables are
    # unnamed, so TF names them by creation order, and tf.train.Saver matches
    # checkpoint entries by those names.
    weights = tf.Variable(initial_value=tf.random_normal(shape=[input_num, output_num]))
    biases = tf.Variable(initial_value=tf.random_normal(shape=[1, output_num]))
    affine = tf.add(tf.matmul(input_data, weights), biases)
    if not activation_function:
        return affine
    return activation_function(affine)
def build_nn(data):
    """Stack the 784 -> 100 -> 50 -> 10 fully-connected network on ``data``.

    The two hidden layers use a sigmoid activation. The last layer is linear,
    and its output is returned as unnormalised logits.

    Args:
        data: Tensor of flattened images, presumably shaped ``[batch, 784]``.

    Returns:
        Logits tensor with 10 columns.
    """
    # (fan_in, fan_out, activation) for each layer, in creation order.
    layer_specs = (
        (28 * 28, 100, tf.sigmoid),
        (100, 50, tf.sigmoid),
        (50, 10, None),
    )
    activations = data
    for fan_in, fan_out, activation in layer_specs:
        activations = add_layer(activations, fan_in, fan_out, activation_function=activation)
    return activations
def train_nn(data):
    """Train the MNIST classifier, or restore it and classify ``./1.jpg``.

    Builds the network on ``data`` with a mean softmax cross-entropy loss and a
    plain gradient-descent step (learning rate 1). Then, in a new session:

    * If a ``checkpoint`` file exists in the working directory, restore
      ``./mnist.ckpt`` and run ``predict`` on ``./1.jpg``.
    * Otherwise:
      - train for 50 epochs of ``batch_size`` mini-batches, printing the
        summed batch loss of each epoch;
      - print the accuracy on ``mnist.test``;
      - save the variables to ``./mnist.ckpt``.

    Args:
        data: Input tensor for the network. The feeds below always target the
            module-level ``x``/``y`` placeholders, so this is expected to be ``x``.
    """
    logits = build_nn(data)
    # Softmax cross-entropy between the labels y and the network's logits,
    # averaged over the examples in the batch.
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)
    mean_loss = tf.reduce_mean(per_example_loss)
    # Each run of this op applies one gradient-descent update.
    train_step = tf.train.GradientDescentOptimizer(learning_rate=1).minimize(mean_loss)
    # Created after build_nn so it covers all of the network's variables.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # NOTE(review): 'checkpoint' is the checkpoint-state file that
        # Saver.save writes next to './mnist.ckpt'. A different model saved in
        # this directory would also satisfy this check.
        if os.path.exists('checkpoint'):
            saver.restore(sess, './mnist.ckpt')
            predict('./1.jpg', sess, logits)
            return

        steps_per_epoch = int(mnist.train.num_examples) // batch_size
        for epoch in range(50):
            epoch_cost = 0
            for _step in range(steps_per_epoch):
                images, labels = mnist.train.next_batch(batch_size)
                batch_cost, _ = sess.run([mean_loss, train_step],
                                         feed_dict={x: images, y: labels})
                epoch_cost += batch_cost
            print('Epoch', epoch, ': ', epoch_cost)

        # Fraction of test images whose arg-max logit matches the one-hot label.
        correct = tf.equal(tf.argmax(y, 1), tf.argmax(logits, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
        saver.save(sess, './mnist.ckpt')
def reconstuct_image():
    """Dump every MNIST training image as a .bmp into a per-digit folder.

    Creates ``./0`` ... ``./9`` if needed. Training example ``j`` is written to
    ``./<label>/<j>.bmp``, where ``<label>`` is the arg-max of its one-hot
    label. Progress is printed every 1000 images.

    The images and labels arrays are iterated directly instead of calling
    ``mnist.train.next_batch(1)``. This visits every example exactly once,
    wherever the shared ``next_batch`` cursor currently is. It also leaves that
    cursor, which ``train_nn`` uses, untouched.
    """
    for digit in range(10):
        os.makedirs('./{}'.format(digit), exist_ok=True)
    for j, (pixels, one_hot) in enumerate(zip(mnist.train.images, mnist.train.labels)):
        # pixels is a flat row of 28*28 intensities, assumed to be in [0, 1]
        # (hence the *255 to get back to 0-255 bytes).
        img = Image.fromarray(np.array(pixels * 255, dtype='uint8').reshape(28, 28))
        label = np.argmax(one_hot)
        img.save('./{}/{}.bmp'.format(label, j))
        if j % 1000 == 0:
            print(" ", j, "/", mnist.train.num_examples)
def read_data(path):
    """Load an image with OpenCV and convert it into a single network input row.

    Args:
        path: Path of the image file to read.

    Returns:
        A tuple ``(image, processed_image)``:

        * ``image`` is the grayscale image as read from disk.
        * ``processed_image`` is that image resized to 28x28, flattened to
          shape ``(1, 784)`` and scaled to float32 values in ``[0, 1]``. That is
          the same range as the MNIST training batches; the training code
          multiplies those by 255 to recover 0-255 bytes.

    Raises:
        OSError: If OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError('cannot read image: {}'.format(path))
    resized = cv2.resize(image, dsize=(28, 28))
    # NOTE(review): MNIST digits are presumably light strokes on a dark
    # background; a dark-on-light input image may need inverting
    # (255 - pixel) first -- confirm against your test images.
    return image, _to_model_input(resized)


def _to_model_input(gray_image):
    """Flatten a 28x28 uint8 grayscale image into a (1, 784) float32 row in [0, 1]."""
    # reshape (not np.resize) so a wrong-sized input raises instead of being
    # silently repeated or truncated.
    return np.reshape(gray_image, (1, 28 * 28)).astype(np.float32) / 255.0
def predict(image_path, sess, output):
    """Classify the image at ``image_path`` and print the predicted class.

    Args:
        image_path: Path of the image to classify.
        sess: Session holding the trained (or restored) model variables.
        output: The network's logits tensor. Its input is fed through the
            module-level ``x`` placeholder.
    """
    _, model_input = read_data(image_path)
    scores = sess.run(output, feed_dict={x: model_input})
    predicted = np.argmax(scores, axis=1)
    print('the prediction is ', predicted)
if __name__ == '__main__':
    # Optional one-off step: export the training set as ./<digit>/<n>.bmp files.
    # reconstuct_image()
    train_nn(data=x)