CIFAR 10 Classification
11230 ワード
Let's get familiar with the TensorFlow framework and understand how a CNN works!
IMPORT LIBRARIES
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# LOAD DATA
# Use the CIFAR-10 dataset loader bundled with tf.keras.
cifar10 = tf.keras.datasets.cifar10
# load_data() returns two (images, labels) pairs: the training split and the
# test split, unpacked here into train_images/train_labels and
# test_images/test_labels.
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
# ANALYSIS DATA
# Check the array shape of each split and how many samples it contains.
print(train_images.shape)
print(len(train_images))
print(test_images.shape)
print(len(test_images))

# Data visualization: draw the first 9 training images in a 3x3 grid.
plt.figure(figsize=(3, 3))
for i in range(9):
    # Select the i-th cell of the 3x3 grid (subplot indices start at 1).
    plt.subplot(3, 3, i + 1)
    plt.imshow(train_images[i])
    plt.colorbar()
    plt.grid(False)
# Display the figure once, after all 9 subplots have been drawn.
plt.show()
# MODEL TRAIN
# Build the model: three Conv2D + MaxPooling2D stages followed by a small
# fully connected classifier head.
model = tf.keras.Sequential([
    # CIFAR-10 images are 32x32 RGB. Declaring the input shape up front builds
    # the model immediately, so model.summary() works before training.
    tf.keras.Input(shape=(32, 32, 3)),
    # Scale raw pixel values from [0, 255] to [0, 1].
    # tf.keras.layers.Rescaling replaces the deprecated
    # tf.keras.layers.experimental.preprocessing.Rescaling alias, which no
    # longer exists in newer Keras releases.
    tf.keras.layers.Rescaling(1. / 255),
    # Convolutional layer with 32 filters of size 3x3 and ReLU activation.
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    # Max pooling (default 2x2 window) to downsample the feature maps.
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(128, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    # Flatten the multi-dimensional feature maps into a 1-D vector.
    tf.keras.layers.Flatten(),
    # Fully connected hidden layer with 128 units.
    tf.keras.layers.Dense(128, activation='relu'),
    # One output per CIFAR-10 class (10 classes). Outputs are raw logits with
    # no softmax, which is why the loss below uses from_logits=True.
    tf.keras.layers.Dense(10),
])
model.summary()

# Compile with the Adam optimizer and sparse categorical cross-entropy.
# The "sparse" variant takes integer class labels directly.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Train for 15 full passes (epochs) over the training set.
model.fit(train_images, train_labels, epochs=15)

# Measure loss and accuracy on the held-out test set.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)
Reference
この問題（CIFAR 10 Classification）についての詳細は、こちらをご覧ください: https://velog.io/@south1128/CIFAR-10-Classification テキストは自由に共有またはコピーできます。ただし、このドキュメントのURLは参考URLとして残しておいてください。
Collection and Share based on the CC Protocol