PyTorchでのデータ読み込み

1858 ワード

torch.utils.data.DataLoader
torch.utils.data.DataLoader(dataset, batch_size, shuffle, num_workers, pin_memory)
(第1引数 dataset には torch.utils.data.Dataset のインスタンスを渡します。)
重要なのは次の2つのクラスです: torch.utils.data.DataLoader と torch.utils.data.Dataset
import torchvision.transforms as transforms

train_loader = torch.utils.data.DataLoader(
    ImageList(
        root=opt.root_path,
        fileList=opt.train_list,
        # ToTensor turns each PIL image into a float tensor laid out as
        # (C, H, W), with pixel values scaled to [0, 1].
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=opt.batch_size,
    shuffle=True,              # reshuffle the sample order every epoch
    num_workers=opt.workers,   # number of loader subprocesses
    pin_memory=True,           # page-locked host memory for faster GPU copies
)

torch.utils.data.Dataset を継承したクラスをデータリーダとして書きます。
#load_imglist.py
import torch.utils.data 

from PIL import Image
import os



def default_list_reader(fileList):
    """Parse an image list file into ``(path, label)`` pairs.

    Each non-blank line holds an image path followed by an integer label,
    separated by whitespace, e.g. ``faces/0001.jpg 3``.  The label is taken
    from the last whitespace-separated field, so a path containing spaces
    is kept intact.  Blank lines (such as a trailing empty line at the end
    of the file) are skipped.

    Args:
        fileList: Path of the list file to read.

    Returns:
        A list of ``(imgPath, label)`` tuples, with ``label`` as an ``int``.

    Raises:
        ValueError: If a non-blank line does not consist of a path followed
            by an integer label.
    """
    imgList = []
    with open(fileList, 'r') as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # Tolerate blank lines instead of failing on the unpack below.
                continue
            # rsplit with maxsplit=1: the label is the last field, and
            # everything before it (spaces included) is the path.
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError(
                    "{}:{}: expected '<path> <label>', got {!r}".format(
                        fileList, lineno, line))
            imgPath, label_text = parts
            try:
                label = int(label_text)
            except ValueError:
                raise ValueError(
                    "{}:{}: label {!r} is not an integer".format(
                        fileList, lineno, label_text))
            imgList.append((imgPath, label))
    return imgList


class ImageList(torch.utils.data.Dataset):
    """Dataset of labelled images listed in a text file.

    Args:
        root: Directory that the image paths in ``fileList`` are relative to.
        fileList: Path of a list file with one ``<path> <label>`` entry per
            line, parsed by ``default_list_reader``.
        transform: Optional callable applied to each loaded PIL image
            (e.g. a ``torchvision.transforms.Compose``).
        mode: PIL mode every image is converted to.  Defaults to ``'L'``
            (single-channel grayscale); pass ``'RGB'`` for 3-channel color.
    """

    def __init__(self, root, fileList, transform=None, mode='L'):
        self.root      = root
        self.imgList   = default_list_reader(fileList)
        self.transform = transform
        self.mode      = mode

    def __getitem__(self, index):
        """Return ``(img, target)`` for the ``index``-th list entry.

        ``img`` is the image converted to ``self.mode`` and then passed
        through ``self.transform`` if one was given; ``target`` is the
        integer label from the list file.
        """
        imgPath, target = self.imgList[index]
        img_loc = os.path.join(self.root, imgPath)

        # Open the file ourselves so the handle is closed deterministically;
        # convert() loads the pixel data, so the result no longer needs it.
        with open(img_loc, 'rb') as f:
            img = Image.open(f).convert(self.mode)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Return the number of entries in the image list."""
        return len(self.imgList)

使用例:
# Walk the mini-batches produced by the DataLoader.  The batch of images is
# bound to `images` rather than `input`, so the built-in `input()` is not
# shadowed for the rest of the script.
for i, (images, target) in enumerate(train_loader):
    print(i, target)
    print(images.shape)

最後に出力された結果は次のとおりです(古いバージョンのPyTorchとPython 2での表示形式です):
(1093, 
 928
[torch.LongTensor of size 1]
)
(1L, 1L, 64L, 64L)

出力されるTensorは4次元 (バッチ, チャンネル, 高さ, 幅) です。DataLoaderが各画像 (1×64×64) をバッチにまとめる際に、先頭にバッチ次元を自動的に追加するためです。