cifar-10データセット処理

10576 ワード

cifar-10データセット処理


CIFAR-10データセットは10クラスの60000個の32 x 32カラー画像からなり、各クラスに6000個の画像がある.50000個のトレーニング画像と10000個のテスト画像があります.訓練画像は5ロットに分け,試験画像は1ロットに分けた.
pythonコードは次のとおりです.
import numpy as np
import random
import pickle
import platform
import os

# 
def load_pickle(f):
    version=platform.python_version_tuple()# python 
    if version[0]== '2':
        return pickle.load(f)
    elif version[0]== '3':
        return pickle.load(f,encoding='latin1')
    raise ValueError("invalid python version:{}".format(version))
# 
def load_CIFAR_batch(filename):
    with open(filename,'rb') as f:
        datadict=load_pickle(f)
        X=datadict['data']
        Y=datadict['labels']
        X=X.reshape(10000,3,32,32).transpose(0,2,3,1).astype("float")
        #reshape() ,transpose() 
        Y=np.array(Y)
        return X,Y

# 
def load_CIFAR10(ROOT):
    xs=[]
    ys=[]
    for b in range(1,6):
        f=os.path.join(ROOT,'data_batch_%d'%(b,))#os.path.join() 
        X,Y=load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr=np.concatenate(xs)# 
    Ytr=np.concatenate(ys)
    del X,Y
    Xte,Yte=load_CIFAR_batch(os.path.join(ROOT,'test_batch'))
    return Xtr,Ytr,Xte,Yte
datasets = './cifar-10-batches-py'
train_x,train_y,test_x,test_y = load_CIFAR10(datasets)
print('train_x shape:%s, train_y shape:%s' % (train_x.shape, train_y.shape))
print('test_x shape:%s, test_y shape:%s' % (test_x.shape, test_y.shape))

出力結果:
train_x shape:(50000, 32, 32, 3), train_y shape:(50000,)
test_x shape:(10000, 32, 32, 3), test_y shape:(10000,)

そのうち50000はピクチャ数、32323は画像のwidth、hight、channel
次はcifar-100データセット処理です.