pytorch DataLoaderカスタムデータセット

4836 ワード

pytorchはmini-batchのデータを生成し、トレーニングやテスト時にマルチスレッド処理を行い、データの準備を速めるデータ処理方式を提供しています.この関数ツールは
torch.utils.data import Dataset, DataLoader

その中でDatasetは私たちが自分のマルチスレッドデータ処理フレームワークを定義する親クラスで、私たちが定義したフレームワークはこのクラスの下で簡単にデータ準備を定義するフレームワークを継承しましょう!!!
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
    def __init__(self, filepath, transform=None,keys = None, target_transform=None):
        pass
    '''
     ,filepath ,transform (features) ,target_transform (labels) ,keys , , , ndarray ,  
    '''
    def __getitem__(self,index):
        pass
    def __len__(self):
        pass

次に、データと組み合わせて、次のコードを説明します.
class MyDataset(Dataset):
    def __init__(self, filepath, transform=None,keys = None, target_transform=None):
        with open(filepath,'rb') as f:
            self.data = pickle.load(f)
        self.keys = keys
        self.input_seq = self.data[self.keys[0]]  ###  
        self.output_seq = self.data[self.keys[1]]   ####  
        self.transform = transform ####  
        self.target_transform = target_transform   ######  

    def __getitem__(self, index):
        input_seq,output_seq = self.input_seq[index],self.output_seq[index]  ##  
        if self.transform is not None:
            input_seq = self.transform(input_seq)
            output_seq = self.transform(output_seq)
        return input_seq,output_seq  ###  

    def __len__(self):
        return self.data[self.keys[0]].shape[0]   ###  , 

train_data = MyDataset(filepath = 'train30.pickle',keys = ['aa','bb'])
test_data = MyDataset(filepath = 'test30.pickle',keys = ['aa','bb'])
train_loader = DataLoader(dataset = train_data,batch_size = 32,shuffle = False)
test_loader = DataLoader(dataset = test_data,batch_size = 32,shuffle = False)

train_を呼び出すことができますloaderとtest_loaderはmini-batchのデータを呼び出し、batch_sizeはtrain_loaderとtest_loaderの値はすでに設定されており、以下のコードで一括して呼び出されます.
from torch.autograd import Variable
for i,(input_seq,out_seq) in enumerate(train_loader):
    input_seq = Variable(input_seq.cuda())
    output_seq = Variable(output_seq.cuda())