TensorFlow CNNによるCIFAR-10画像分類 2


CIFAR-10データセットは次のURLからダウンロードできます: http://www.cs.toronto.edu/~kriz/cifar.html
TensorFlow CNN对CIFAR10图像分类2_第1张图片
ページ内の「CIFAR-10 python version」をダウンロードし、展開したバッチファイルを以下のコードが参照する cifar10/ ディレクトリに置きます。
 
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np

def unpickle(file):
    """Load a pickled object (e.g. one CIFAR-10 batch) from a file.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. For the CIFAR-10 batch files used in this
        script it is a dict that is indexed with the 'data' and 'labels' keys.
    """
    import pickle

    # NOTE: pickle can execute arbitrary code while loading; only use this on
    # files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # encoding='latin1' is what the pickle docs require for reading NumPy
        # arrays that were pickled by Python 2.
        batch = pickle.load(fo, encoding='latin1')
    return batch

def onehot(labels, n_class=None):
    """Convert integer class ids into a one-hot encoded matrix.

    Args:
        labels: 1-D sequence (list or array) of non-negative integer class ids.
        n_class: Number of columns of the result. When omitted it is inferred
            as ``max(labels) + 1``; pass it explicitly when the labels may not
            contain the highest class id (e.g. a small subset), otherwise the
            result would be too narrow.

    Returns:
        A float64 ``np.ndarray`` of shape ``(len(labels), n_class)`` in which
        row ``i`` holds 1 at column ``labels[i]`` and 0 elsewhere. Empty
        input yields an array with zero rows.
    """
    labels = np.asarray(labels, dtype=np.intp)
    n_sample = labels.shape[0]
    if n_class is None:
        # Nothing to infer the width from when there are no labels at all.
        n_class = int(labels.max()) + 1 if n_sample else 0
    onehot_labels = np.zeros((n_sample, n_class))
    if n_sample:
        # Fancy indexing sets exactly one cell per row.
        onehot_labels[np.arange(n_sample), labels] = 1
    return onehot_labels

# Read the five training batch files; data1..data5 stay bound to the
# individual batches exactly as before.
train_batches = [
    unpickle(path)
    for path in (
        'cifar10/data_batch_1',
        'cifar10/data_batch_2',
        'cifar10/data_batch_3',
        'cifar10/data_batch_4',
        'cifar10/data_batch_5',
    )
]
data1, data2, data3, data4, data5 = train_batches

# Stack the per-batch arrays along axis 0, then one-hot encode the labels.
X_train = np.concatenate([b['data'] for b in train_batches], axis=0)
Y_train = onehot(np.concatenate([b['labels'] for b in train_batches], axis=0))

# Only the first 5000 rows of the test batch are kept. The labels are encoded
# before slicing, matching the order of operations of the original script.
test = unpickle('cifar10/test_batch')
X_test = test['data'][:5000, :]
Y_test = onehot(test['labels'])[:5000, :]

# Report the array shapes, one per line.
for split in (X_train, Y_train, X_test, Y_test):
    print(split.shape)

# Inputs: flattened images (32*32*3 values per row) and one-hot labels over
# 10 classes.
xs = tf.placeholder(tf.float32, [None, 32*32*3])
ys = tf.placeholder(tf.float32, [None, 10])
# CIFAR-10 rows are stored channel-major (1024 red, then 1024 green, then 1024
# blue values), so unpack them as [N, 3, 32, 32] first and move the channel
# axis last, which is the layout tf.layers.conv2d expects by default.
# Reshaping straight to [-1, 32, 32, 3] would scramble the pixels.
x_image = tf.transpose(tf.reshape(xs, [-1, 3, 32, 32]), [0, 2, 3, 1])

# Block 1: 5x5 conv (32 filters) -> 3x3 average pool, stride 2 -> LRN.
conv1 = tf.layers.conv2d(x_image, 32, 5, 1, 'same', activation=tf.nn.relu)
pool1 = tf.layers.average_pooling2d(conv1, 3, 2, padding='same')
poo11 = pool1  # original (misspelled) name, kept so existing references resolve
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)

# Block 2: 5x5 conv (64 filters) -> LRN -> 3x3 average pool, stride 2.
conv2 = tf.layers.conv2d(norm1, 64, 5, 1, 'same', activation=tf.nn.relu)
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)
pool2 = tf.layers.average_pooling2d(norm2, 3, 2, padding='same')
# Two stride-2 poolings with 'same' padding reduce 32x32 to 8x8.
pool2_flat = tf.reshape(pool2, [-1, 8*8*64])

# Classifier head.
fc1 = tf.layers.dense(pool2_flat, 384, activation=tf.nn.relu)
fc2 = tf.layers.dense(fc1, 192, activation=tf.nn.relu)
# Keep the raw scores separate from the probabilities:
# softmax_cross_entropy applies softmax itself, so feeding it already
# normalised probabilities would apply softmax twice and flatten the gradients.
logits = tf.layers.dense(fc2, 10)
output = tf.nn.softmax(logits)

loss = tf.losses.softmax_cross_entropy(onehot_labels=ys, logits=logits)
train_step = tf.train.GradientDescentOptimizer(learning_rate=1e-3).minimize(loss)

# Fraction of rows whose predicted class matches the labelled class.
# tf.argmax replaces the deprecated tf.arg_max alias.
correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

batch_size = 50
n_epochs = 400
# The held-out set is evaluated in slices so a single feed stays small.
eval_chunk = 500


def _mean_accuracy(sess, images, labels):
    """Return `accuracy` over all rows, evaluated in slices of `eval_chunk`.

    Each slice's accuracy is weighted by its row count so a shorter final
    slice does not skew the mean.
    """
    n_total = images.shape[0]
    n_correct = 0.0
    for start in range(0, n_total, eval_chunk):
        x_part = images[start:start + eval_chunk]
        y_part = labels[start:start + eval_chunk]
        part_acc = sess.run(accuracy, feed_dict={xs: x_part, ys: y_part})
        n_correct += part_acc * x_part.shape[0]
    return n_correct / n_total


init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # Integer division: a trailing partial batch is skipped, as before.
    total_batch = X_train.shape[0] // batch_size
    for i in range(n_epochs):
        for batch in range(total_batch):
            rows = slice(batch * batch_size, (batch + 1) * batch_size)
            batch_x = X_train[rows]
            batch_y = Y_train[rows]
            sess.run(train_step, feed_dict={xs: batch_x, ys: batch_y})
        # Accuracy on the last training batch (the value the script always
        # printed) ...
        acc = sess.run(accuracy, feed_dict={xs: batch_x, ys: batch_y})
        # ... plus accuracy on the held-out test rows, which were loaded but
        # never evaluated; 50 already-trained-on samples say little about
        # generalisation.
        test_acc = _mean_accuracy(sess, X_test, Y_test)
        print(i, acc, test_acc)



400エポック実行した結果:
0 0.38 1 0.38 2 0.4 3 0.44 0.44 0.44 0.44 0.44 6 0.42 0.44 4 0.46 0.46 0.46 10 0.48 0.5 12 0.52 0.52 0.52 0.52 16 0.52 0.52 18 0.54 18 0.54 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.56 0.533662 0.58 0.58 0.58 0.58 0.58 0.58 0.68 0.58 0.58 0.68 0.65 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.68 0.77 0.68 0.68 0.68 0.68 0.77 0.68 0.77 0.68 0.68 0.68 0.77 0.68 0.77 0.68 0.77 0.68 0.77 0.775 0.770.7 92 0.7 93 0.7 94 0.7 94 0.7 95 0.7 96 0.7 0.7 0.7 0.7 98 0.7 0.7 0.7 100 0.7 101 0.7 102 0.7 103 0.7 104 0.7 0.7 105 0.7 106 0.7 107 0.7 108 0.7 110 0.7 111 0.7 112 0.7 114 0.7 0.7 116 0.7 0.7 119 0.7 123 0.7 123 123 123 123 123 123 123 1270.72 0.72 138 0.72 138 0.72 139 0.72 0.72 140 0.72 141 0.72 142 0.72 144 0.72 145 0.72 0.72 146 0.72 0.72 147 0.72 0.72 148 0.72 148 0.72 150.72 152 0.72 150.72 150.72 150.72 150.72 157 0.72 157 0.72 157 0.72 150.72 158 0.72 150.72175 0.72 177 0.72 177 0.72 178 0.72 179 0.72 180 0.72 181 0.72 0.72 183 0.72 184 0.72 185 0.72 186 0.72 186 0.72 188 0.72 188 0.72 188 0.72 191 0.72 192 0.72 192 0.72 193 0.72 194 0.72 196 0.72 197 0.72 0.72 198 0.724 214 0.74 215 0.74 216 0.74 217 0.74 217 0.74 219 0.74 220 0.74 221 0.74 222 0.74 224 0.74 0.74 0.74 224 0.74 0.74 0.74 225 0.74 220.74 0.74 0.74 227 0.74 0.74 227 0.74 0.74 234 0.74 234 0.74 234 0.74 234 0.74 230.74 236 230.74 236 230.740.70.70.76 0.70.76 0.70.76 0.76 0.76 0.70.76 0.76 0.70.70.76 0.70.70.76 0.70.70.76 0.70.70.70.76 0.70.70.70.70.76 0.70.70.76 0.70.70.76 0.70.76 0.70.76 0.70.70.76 0.70.70.76 0.70.76 0.70.70.70.76 0.70.70.70.76 0.70.70.70.70.70.70.70.76 0.70.70.70.70.70.70.70.76 0.70.70.70.70.70.70.70.76 0.70.70.70.70.70.70.70.76 0.70.70.70.70.70.70.70.76 0.70.70.70.70.70.70.70.70.76 0.70.70.70.70.70.76 0.70.70.70.70.76 0.70.70.70.70.70.76 0.70.76 0.70.70.70.76 0.76 0.76 0.70.76 0.70.70.70.70.70.70.70.76 0.70.70.70.70.70.70.76 0.70.70.70.70.76 0.70.70.70.70.70.76 0.76 0.72910.76 
292 0.76 293 0.76 0.76 294 0.76 0.76 0.76 296 0.76 0.76 0.76 0.76 296 0.76 0.76 0.76 0.76道300 0.76 302 0.76 303 0.76 304 0.76 305 0.76 306 306 0.76 307 0.76 308 0.76 316 0.7326 316 0.760.76 0.76 0.70.76 331 0.70.76 0.70.76 0.70.76 0.70.70.76 0.70.70.76 0.70.70.76 0.70.70.76 0.70.70.70.76 0.70.70.70.76 330.70.76 0.70.76 0.70.70.76 0.70.70.70.76 0.70.70.70.76 0.70.70.70.76 0.70.70.76 0.70.70.70.70.76 0.70.70.70.70.76 0.70.70.70.70.70.70.70.76 336 0.70.76 0.70.70.70.70.70.76 0.70.70.70.70.76 0.70.70.70.70.70.70.70.70.70.70.76 0.70.70.76 0.70.76 0.70.70.70.70.70.76 336 0.70.70.70.70.76 0.70.70.70.70.70.76 336 0.70.76 0.70.76 0.76 0.70.70.70.76 0.70.70.70.70.70.70.70.70.76 336 0.70.70.76 0.70.70.70.70.70.70.70.70.70.70.76 336 0.76369 0.76 370 0.76 371 0.76 372 0.76 372 0.76 373 0.76 0.76 375 0.76 376 0.76 377 0.76 378 0.76 378 0.76 379 0.76 0.76 380 0.76 381 0.76 382 0.76 0.76 384 0.76 385 0.76 0.76 386 0.76 387 0.76 387 0.76 386 0.76 389 0.76 389