# フォーマット変換：画像を pkl 形式のトレーニングセット／テストセットに一括変換
# (Format conversion: batch-convert images into pickled training and test sets.)
# 3653 ワード
from PIL import Image
import os
import pickle
import numpy as np
import time
from random import shuffle
def create_dataset(filename, write_to_file_root, channel, image_suffix, num_test,
                   step_every_print=500, seed=None):
    """Create the znyp train/test datasets as pickle files.

    Every file under ``filename`` whose name ends with ``image_suffix`` is
    collected and shuffled. The first ``num_test`` paths form the test split
    and the remaining paths form the training split. Each split is written by
    ``_pickel_dataset`` under the categories 'testsh' and 'trainsh'.

    input:
        filename: root directory to scan (e.g., './classes'). The name of each
            image's parent directory must be its integer class label
            (e.g., './classes/3/img.jpg').
        write_to_file_root: the root location of output files (e.g., './')
        channel:
            1: there is only 1 channel
            3: the image has three channels (R, G, B)
        image_suffix: suffix of the image files (e.g., '.jpg'); a tuple of
            suffixes is also accepted, since ``str.endswith`` takes tuples
        num_test: number of shuffled images assigned to the test split; must
            be >= 0. If it exceeds the number of images, every image goes to
            the test split and the training split is empty.
        step_every_print: print progress every this many images
        seed: optional seed for the shuffle. ``None`` (the default) keeps the
            original behaviour of using the module-level ``random.shuffle``.

    raises:
        ValueError: if ``num_test`` is negative.
    """
    if num_test < 0:
        # A negative value would make the slices below take "all but the
        # last n" for test, which is never what a caller means.
        raise ValueError('num_test must be non-negative, got %r' % (num_test,))

    image_paths = []
    for root, _dirnames, names in os.walk(filename):
        for name in names:
            if name.endswith(image_suffix):
                image_paths.append(os.path.join(root, name))

    # os.walk order depends on the filesystem; sort first so that a given
    # seed always produces the same split.
    image_paths.sort()
    if seed is None:
        shuffle(image_paths)
    else:
        import random
        random.Random(seed).shuffle(image_paths)

    test_paths = image_paths[:num_test]
    train_paths = image_paths[num_test:]
    _pickel_dataset(test_paths, write_to_file_root, 'testsh', channel, step_every_print)
    _pickel_dataset(train_paths, write_to_file_root, 'trainsh', channel, step_every_print)
def _pickel_dataset(filename, write_to_file_root, category, channel, step_every_print=500):
    """Pickle a list of labelled images into ``<write_to_file_root>/<category>_data``.

    input:
        filename: a list of image paths. The name of each image's parent
            directory must be its integer class label (e.g., './classes/3/a.jpg').
        write_to_file_root: output directory; it is created if it does not exist
        category: dataset category, used as the output file prefix (e.g., 'trainsh')
        channel:
            1: store ``np.array(image)`` flattened as-is
            3: store the R plane, then the G plane, then the B plane of each image
        step_every_print: print progress every this many images; 0 or None
            disables the progress output

    return:
        the dict that was pickled, with keys 'data' (a 1-D uint8 array holding
        all images' values concatenated in list order), 'labels', 'height',
        'width', 'channel' and 'num_images'.

    raises:
        ValueError: if ``channel`` is not 1 or 3, if an image lacks three
            colour channels when ``channel == 3``, or if the images do not all
            share the same height and width.
    """
    # Validate once up front instead of failing only after the first image is read.
    if channel not in (1, 3):
        raise ValueError('The channel for the image should be 1 or 3')

    start_time = time.time()
    chunks = []  # per-image uint8 arrays, concatenated once at the end
    labels = []
    height = width = None

    for idx, path in enumerate(filename):
        if step_every_print and idx % step_every_print == 0:
            print('Current append image : %s' % path)

        # The context manager closes the file handle immediately instead of
        # leaving it to the garbage collector (thousands of files are opened).
        with Image.open(path) as img:
            pixels = np.array(img)

        h, w = pixels.shape[0], pixels.shape[1]
        if height is None:
            height, width = h, w
        elif (h, w) != (height, width):
            # Mixed sizes would make the flat 'data' array impossible to reshape.
            raise ValueError('Image %s is %dx%d but earlier images are %dx%d'
                             % (path, h, w, height, width))

        if channel == 1:
            chunks.append(pixels.ravel())
        else:
            if pixels.ndim != 3 or pixels.shape[2] < 3:
                raise ValueError('Image %s has shape %s; channel=3 needs at least '
                                 '3 colour channels' % (path, pixels.shape))
            # Planar layout: all red values, then all green, then all blue.
            chunks.extend(pixels[:, :, c].ravel() for c in range(3))

        # The label is the integer name of the image's parent directory.
        labels.append(int(os.path.basename(os.path.dirname(path))))

    if chunks:
        data = np.concatenate(chunks).astype(np.uint8, copy=False)
    else:
        data = np.empty(0, dtype=np.uint8)
    if height is None:
        # No images: keep the 32x32 metadata the original code always wrote.
        height = width = 32

    datadict = {'data': data, 'labels': labels, 'height': height, 'width': width,
                'channel': channel, 'num_images': len(labels)}

    if write_to_file_root:
        os.makedirs(write_to_file_root, exist_ok=True)
    outname = os.path.join(write_to_file_root, category + '_data')
    with open(outname, 'wb') as f:
        # Protocol 1 matches the original ``pickle.dump(..., True)`` call
        # (True == 1), so existing loaders keep working.
        pickle.dump(datadict, f, 1)

    print('%s set took %.2f seconds' % (category, time.time() - start_time))
    return datadict
if __name__ == '__main__':
    # Build the 'testsh' split (12630 shuffled images) and the 'trainsh'
    # split (all remaining images) from the .jpg files under ./train4,
    # writing the pickles into ./znyp_dataset as 3-channel data.
    create_dataset(
        './train4',
        './znyp_dataset',
        3,
        '.jpg',
        12630,
    )