【Pytorch】---データの読み出しと操作(Dataset,DataLoader)

7268 ワード

前言
Pytorchで重要なのはデータの処理です.ここで、データの読み取りを行うのは一般的に3つのクラスがあります.
  • Dataset
  • DataLoader

  • このうち,この3つは順次パッケージ化の関係である:「Dataset被パッケージ化DataLoader,DataLoader再パッケージ化DataLoaderIter
    Dataset Datasettorch.utils.data.Datasetにあり、クラスMyDatasetをカスタマイズするたびに、2つのメンバー関数を継承し、実装する必要があります.
  • __len__()
  • __getitem__()

  • 例:
    import torch
    from torch.utils.data import Dataset
    import pandas as pd
    
    #       
    class MyDataset(Dataset):
        
        #    
        def __init__(self, file_name):
            #     
            self.data = pd.read_csv(file_name)
        
        #   df   
        def __len__(self):
            return len(self.data)
        
        #    idx+1    
        def __getitem__(self, idx):
            return self.data[idx].label
    
    #             
    #           median_benchmark.csv   
    ds = MyDataset('median_benchmark.csv')
    
    '''
    len(ds)           
    ds[101]             
    '''
    

    DataLoader DataLoadertorch.utils.data.DataLoaderにあり、Datasetの読み取り操作を提供しています.
    #             
    torch.nn.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    
  • dataset:前述のカスタムクラスDataset
  • batch_size:デフォルトは1で、毎回読み込まれるbatchのサイズ
  • shuffle:デフォルトはFalseであり、データに対してshuffle操作を行うかどうか(データセットを乱すと簡単に理解できる)
  • num_works:デフォルトは0で、データをロードするたびにサブプロセスを使用する数、すなわち単純なマルチスレッドプリフェッチデータの方法
  • DataLoaderは反復器を返し、この反復器によってデータを取得します.Dataloderの目的は、所与のn n n個のデータをDataloader操作後、呼び出しのたびに小さなbatchを呼び出すことである.
  • は,(5000,28,28)(5000,28,28)(5000,28,28)(5000,28,28)であり,5000,000,5000個のサンプルがあることを示し,各サンプルのsizeは(28,28)(28,28)
  • である.
  • Dataloaderの処理後、一度に得られたものは(100,28,28)(100,28,28)(100,28,28,28)(batch_sizeサイズを100とする)であり、今回100個のサンプルを取り出したことを示し、各サンプルのsizeは(28,28)(28,28)
  • である
    #      Dataset    
    
    from torch.utils.data import DataLoader
    
    dl = DataLoader(ds, batch_size=10, shuffle=True, num_works=2)
    

    反復器によるデータの分割取得:
    dl_data = iter(dl)
    print(next(dl_data))
    
    '''
    Output:(  )
    
    tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
            24000.], dtype=torch.float64)
    
    '''
    

    あるいは、直接forループによる遍歴出力
    for i, data in enumerate(dl):
        print(i, data)
        
        #        ,   break
        break
    
    '''
    Output:(  )
    
    0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
            24000.], dtype=torch.float64)
    ''''
    

    参考資料
  • [1]. https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
  • [2]. https://github.com/pandadreamer/pytorch-handbook/blob/master/chapter2/2.1.4-pytorch-basics-data-lorder.ipynb
  • [3]. https://zhuanlan.zhihu.com/p/30934236