cifar-10データセット処理
10576 ワード
cifar-10データセット処理
CIFAR-10データセットは10クラスの60000個の32 x 32カラー画像からなり、各クラスに6000個の画像がある.50000個のトレーニング画像と10000個のテスト画像があります.訓練画像は5ロットに分け,試験画像は1ロットに分けた.
pythonコードは次のとおりです.
import numpy as np
import random
import pickle
import platform
import os
#
def load_pickle(f):
version=platform.python_version_tuple()# python
if version[0]== '2':
return pickle.load(f)
elif version[0]== '3':
return pickle.load(f,encoding='latin1')
raise ValueError("invalid python version:{}".format(version))
#
def load_CIFAR_batch(filename):
with open(filename,'rb') as f:
datadict=load_pickle(f)
X=datadict['data']
Y=datadict['labels']
X=X.reshape(10000,3,32,32).transpose(0,2,3,1).astype("float")
#reshape() ,transpose()
Y=np.array(Y)
return X,Y
#
def load_CIFAR10(ROOT):
xs=[]
ys=[]
for b in range(1,6):
f=os.path.join(ROOT,'data_batch_%d'%(b,))#os.path.join()
X,Y=load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr=np.concatenate(xs)#
Ytr=np.concatenate(ys)
del X,Y
Xte,Yte=load_CIFAR_batch(os.path.join(ROOT,'test_batch'))
return Xtr,Ytr,Xte,Yte
datasets = './cifar-10-batches-py'
train_x,train_y,test_x,test_y = load_CIFAR10(datasets)
print('train_x shape:%s, train_y shape:%s' % (train_x.shape, train_y.shape))
print('test_x shape:%s, test_y shape:%s' % (test_x.shape, test_y.shape))
出力結果:
train_x shape:(50000, 32, 32, 3), train_y shape:(50000,)
test_x shape:(10000, 32, 32, 3), test_y shape:(10000,)
そのうち50000はピクチャ数、32323は画像のwidth、hight、channel
次はcifar-100データセット処理です.