PyTorchデータ読み出し
torch.utils.data.DataLoader
torch.utils.data.DataLoader(torch.utils.data.dataset,batch_size,shuffle,num_workers,pin_memory)
重要なのはこの2つのクラスです:torch.utils.data.DataLoader torch.utils.data.dataset
クラスをデータリーダとして書き、torchを継承する.utils.data.dataset
使用
出力の最後の結果は
出力されるTensorは4次元で、画像を自動的に1次元加算します.
torch.utils.data.DataLoader(torch.utils.data.dataset,batch_size,shuffle,num_workers,pin_memory)
重要なのはこの2つのクラスです:torch.utils.data.DataLoader torch.utils.data.dataset
import torchvision.transforms as transforms
train_loader = torch.utils.data.DataLoader(
ImageList(root=opt.root_path, fileList=opt.train_list,
transform=transforms.Compose([
transforms.ToTensor(), # Tensor ,
])),
batch_size=opt.batch_size, shuffle=True,
num_workers=opt.workers, pin_memory=True)
クラスをデータリーダとして書き、torchを継承する.utils.data.dataset
#load_imglist.py
import torch.utils.data
from PIL import Image
import os
def default_list_reader(fileList):
imgList = []
with open(fileList, 'r') as file:
for line in file.readlines():
imgPath, label = line.strip().split(' ')
imgList.append((imgPath, int(label)))
return imgList
class ImageList(torch.utils.data.Dataset):
def __init__(self, root, fileList, transform=None):
self.root = root
self.imgList = default_list_reader(fileList)
self.transform = transform
def __getitem__(self, index):
imgPath, target = self.imgList[index]
print(imgPath)
img_loc=os.path.join(self.root, imgPath)
img = Image.open(img_loc).convert('L') # , RGB
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.imgList)
使用
for i,(input,target) in enumerate(train_loader):
print(i,target)
print(input.shape)
出力の最後の結果は
(1093,
928
[torch.LongTensor of size 1]
)
(1L, 1L, 64L, 64L)
出力されるTensorは4次元で、画像を自動的に1次元加算します.