Python + TensorFlow で 3 層全結合ニューラルネットワークによる手書き数字認識を実現する

3707 ワード

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
import cv2
import os

# Graph inputs shared by every function below: a batch of flattened 28x28
# images and the matching one-hot labels over the 10 digit classes.
x = tf.placeholder(tf.float32, [None, 28 * 28], name='x')
y = tf.placeholder(tf.float32, [None, 10], name='y')

# MNIST data set read from ./MNIST, with labels encoded as one-hot vectors.
mnist = input_data.read_data_sets('./MNIST', one_hot=True)

# Number of examples per mini-batch during training.
batch_size = 1000
def add_layer(input_data, input_num, output_num, activation_function=None):
    """Append one fully connected layer to the graph.

    Computes ``input_data @ weights + bias`` and passes the result through
    ``activation_function`` when one is given. Weights and bias start from
    ``tf.random_normal`` samples.

    Args:
        input_data: 2-D input tensor with ``input_num`` columns. (This name
            shadows the module-level ``input_data`` import inside this function.)
        input_num: number of input features.
        output_num: number of units in this layer.
        activation_function: optional callable applied to the affine output.

    Returns:
        The layer's output tensor with ``output_num`` columns.
    """
    # Create weights before bias: the Saver keys checkpoints on the
    # auto-generated variable names, which depend on creation order.
    weights = tf.Variable(initial_value=tf.random_normal(shape=[input_num, output_num]))
    bias = tf.Variable(initial_value=tf.random_normal(shape=[1, output_num]))
    pre_activation = tf.add(tf.matmul(input_data, weights), bias)
    if not activation_function:
        return pre_activation
    return activation_function(pre_activation)


def build_nn(data):
    """Build the network 784 -> 100 (sigmoid) -> 50 (sigmoid) -> 10.

    Args:
        data: input tensor of flattened 28x28 images.

    Returns:
        The final layer's output with no activation applied (raw logits).
    """
    # (fan_in, fan_out, activation) for each layer, in creation order.
    layer_specs = (
        (28 * 28, 100, tf.sigmoid),
        (100, 50, tf.sigmoid),
        (50, 10, None),
    )
    layer = data
    for fan_in, fan_out, activation in layer_specs:
        layer = add_layer(layer, fan_in, fan_out, activation_function=activation)
    return layer


def train_nn(data, epochs=50, learning_rate=1, ckpt_path='./mnist.ckpt',
             image_path='./1.jpg'):
    """Train the network on MNIST, or restore it and classify one image.

    If no checkpoint exists at ``ckpt_path``, the network is trained with
    plain gradient descent for ``epochs`` passes over ``mnist.train``. The
    summed loss of each epoch is printed, then the accuracy on ``mnist.test``,
    and the weights are saved to ``ckpt_path``. Otherwise the saved weights
    are restored and the image at ``image_path`` is classified via
    ``predict``.

    Args:
        data: input tensor for the network. Training and evaluation feed the
            module-level placeholders ``x`` and ``y``, so ``data`` is expected
            to be ``x``.
        epochs: number of passes over the training set.
        learning_rate: gradient-descent step size.
        ckpt_path: checkpoint prefix used for both saving and restoring.
        image_path: image to classify after a checkpoint has been restored.
    """
    output = build_nn(data)
    # Softmax cross-entropy between the one-hot labels fed into `y` and the
    # network's logits, averaged over the batch.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output))
    optimizer = tf.train.GradientDescentOptimizer(
        learning_rate=learning_rate).minimize(loss)
    # Fraction of examples whose arg-max prediction matches the label.
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(output, 1)), tf.float32))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Look for the index file of the checkpoint this function saves and
        # restores. A bare "checkpoint" file in the working directory may
        # belong to another model and would make restore() fail.
        if not os.path.exists(ckpt_path + '.index'):
            batches_per_epoch = int(mnist.train.num_examples) // batch_size
            for i in range(epochs):
                each_cost = 0
                for _ in range(batches_per_epoch):
                    x_data, y_data = mnist.train.next_batch(batch_size)
                    cost = sess.run([loss, optimizer],
                                    feed_dict={x: x_data, y: y_data})[0]
                    each_cost += cost
                print('Epoch', i, ': ', each_cost)
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
            print(acc)
            saver.save(sess, ckpt_path)
        else:
            saver.restore(sess, ckpt_path)
            predict(image_path, sess, output)





def reconstuct_image():
    """Export every MNIST training image as ``./<digit>/<index>.bmp``.

    Creates one directory per digit class (``./0`` .. ``./9``) if it is
    missing. Each training example is written once, into the directory of
    its label, and named by its position in ``mnist.train``. Progress is
    printed every 1000 images.
    """
    for digit in range(10):
        os.makedirs('./{}'.format(digit), exist_ok=True)
    total = mnist.train.num_examples
    # Walk the stored arrays directly instead of calling next_batch(1).
    # next_batch shuffles and continues from wherever earlier callers left
    # off, so a single pass can repeat some examples and skip others.
    for j, (pixels, one_hot) in enumerate(zip(mnist.train.images, mnist.train.labels)):
        # Scale back to 0-255. Round before the uint8 cast, because casting
        # truncates: a float such as 253.99998 would otherwise become 253.
        gray = np.rint(np.asarray(pixels) * 255).astype('uint8').reshape(28, 28)
        img = Image.fromarray(gray)
        label = np.argmax(one_hot)
        img.save('./{}/{}.bmp'.format(label, j))
        if j % 1000 == 0:
            print("   ", j, "/", total)


def read_data(path):
    """Load an image with OpenCV and turn it into a single network input row.

    Args:
        path: path of the image file to read.

    Returns:
        ``(image, processed_image)``. ``image`` is the full-size grayscale
        image exactly as loaded. ``processed_image`` is a ``(1, 784)`` float32
        array: the image resized to 28x28 and scaled to [0, 1].

    Raises:
        OSError: if OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise OSError('could not read image: {}'.format(path))
    resized = cv2.resize(image, dsize=(28, 28))
    # The network was trained on MNIST pixels in [0, 1] (the export code
    # multiplies them by 255 to obtain bytes). Feeding raw 0-255 values here
    # would saturate the sigmoid layers. reshape, unlike np.resize, never
    # silently repeats or drops pixels.
    processed_image = resized.reshape(1, 28 * 28).astype(np.float32) / 255.0
    # NOTE(review): MNIST digits are light strokes on a dark background. A
    # photo of dark ink on white paper probably needs inverting first — confirm
    # for the images you feed in.
    return image, processed_image

def predict(image_path, sess, output):
    """Classify the image at ``image_path`` with the network in ``sess``.

    Args:
        image_path: path of the image to classify (loaded via ``read_data``).
        sess: session whose variables hold the trained or restored weights.
        output: the network's logits tensor, fed through placeholder ``x``.

    Returns:
        A 1-element array with the predicted digit (arg-max of the logits).
        The same value is also printed.
    """
    _, processed_image = read_data(image_path)
    logits = sess.run(output, feed_dict={x: processed_image})
    result = np.argmax(logits, 1)
    print('the prediction is ', result)
    return result


if __name__ == '__main__':
    # Optional one-off step: export the MNIST training set as per-digit .bmp
    # files (reconstuct_image). Uncomment to run it.
    # reconstuct_image()
    # First run: train the network and save it. Later runs, once the checkpoint
    # exists: restore the saved weights and classify a sample image.
    train_nn(x)