Caltech-256データセット処理(3)トレーニングセットと検証セットをPyTorch Dateloaderにロード

10290 ワード

PyTorchでのCaltech-256データセットの処理:Caltech-256データセット処理(一)label抽出Caltech-256データセット処理(二)トレーニングセットとテストセットの作成Caltech-256データセット処理(三)トレーニングセットと検証セットDateloaderへのロード
  • Caltech-256の各ピクチャのサイズは一定ではないので、ここではcrop操作を行う必要がある.
  • ここでサボって、meanとstdはimagenetのデータに行って、厳密には単独で計算する必要があります.
  • rstrip()とstrip()は、具体的なシーンに応じて柔軟に使用できますが、ここでは念のため多く使用されています.
  • import torch
    from torch.autograd import Variable
    from torchvision import transforms
    from torch.utils.data import Dataset, DataLoader
    from PIL import Image
    
    root='/media/this/02ff0572-4aa8-47c6-975d-16c3b8062013/'
    
    def default_loader(path):
        return Image.open(path).convert('RGB')
    
    class MyDataset(Dataset):
        def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
            fh = open(txt, 'r')
            imgs = []
            for line in fh:
                line = line.rstrip()
                line = line.strip('
    '
    ) line = line.rstrip() words = line.split() imgs.append((words[0],int(words[1]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader def __getitem__(self, index): fn, label = self.imgs[index] img = self.loader(fn) if self.transform is not None: img = self.transform(img) return img,label def __len__(self): return len(self.imgs) mean = [ 0.485, 0.456, 0.406 ] std = [ 0.229, 0.224, 0.225 ] transform = transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean = mean, std = std), ]) train_data = MyDataset(txt=root+'dataset-trn.txt', transform=transform) test_data = MyDataset(txt=root+'dataset-val.txt', transform=transform) train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True) test_loader = DataLoader(dataset=test_data, batch_size=64) ''' for idx, (data, target) in enumerate(test_loader): if(idx%10==0): print(str(idx)+' '+str(target)) for idx, (data, target) in enumerate(train_loader): if(idx%10==0): print(str(idx)+' '+str(target)) '''