【Pytorch】---データの読み出しと操作(Dataset,DataLoader)
7268 ワード
前言
Pytorchで重要なのはデータの処理です.ここで、データの読み取りを行うのは一般的に3つのクラスがあります. Dataset DataLoader
このうち,この3つは順次パッケージ化の関係である:「
Dataset
例:
DataLoaderは,(5000,28,28)(5000,28,28)(5000,28,28)(5000,28,28)であり,5000,000,5000個のサンプルがあることを示し,各サンプルのsizeは(28,28)(28,28) である.は である
反復器によるデータの分割取得:
あるいは、直接forループによる遍歴出力
参考資料 [1]. https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader [2]. https://github.com/pandadreamer/pytorch-handbook/blob/master/chapter2/2.1.4-pytorch-basics-data-lorder.ipynb [3]. https://zhuanlan.zhihu.com/p/30934236
Pytorchで重要なのはデータの処理です.ここで、データの読み取りを行うのは一般的に3つのクラスがあります.
このうち,この3つは順次パッケージ化の関係である:「
Dataset
被パッケージ化DataLoader
,DataLoader
再パッケージ化DataLoaderIter
」Dataset
Dataset
はtorch.utils.data.Dataset
にあり、クラスMyDataset
をカスタマイズするたびに、2つのメンバー関数を継承し、実装する必要があります.__len__()
__getitem__()
例:
import torch
from torch.utils.data import Dataset
import pandas as pd
#
class MyDataset(Dataset):
#
def __init__(self, file_name):
#
self.data = pd.read_csv(file_name)
# df
def __len__(self):
return len(self.data)
# idx+1
def __getitem__(self, idx):
return self.data[idx].label
#
# median_benchmark.csv
ds = MyDataset('median_benchmark.csv')
'''
len(ds)
ds[101]
'''
DataLoader
DataLoader
はtorch.utils.data.DataLoader
にあり、Dataset
の読み取り操作を提供しています.#
torch.nn.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
dataset
:前述のカスタムクラスDataset
batch_size
:デフォルトは1で、毎回読み込まれるbatchのサイズshuffle
:デフォルトはFalseであり、データに対してshuffle操作を行うかどうか(データセットを乱すと簡単に理解できる)num_works
:デフォルトは0で、データをロードするたびにサブプロセスを使用する数、すなわち単純なマルチスレッドプリフェッチデータの方法DataLoader
は反復器を返し、この反復器によってデータを取得します.Dataloder
の目的は、所与のn n n個のデータをDataloader
操作後、呼び出しのたびに小さなbatchを呼び出すことである.Dataloader
の処理後、一度に得られたものは(100,28,28)(100,28,28)(100,28,28,28)(batch_sizeサイズを100とする)であり、今回100個のサンプルを取り出したことを示し、各サンプルのsizeは(28,28)(28,28)# Dataset
from torch.utils.data import DataLoader
dl = DataLoader(ds, batch_size=10, shuffle=True, num_works=2)
反復器によるデータの分割取得:
dl_data = iter(dl)
print(next(dl_data))
'''
Output:( )
tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
'''
あるいは、直接forループによる遍歴出力
for i, data in enumerate(dl):
print(i, data)
# , break
break
'''
Output:( )
0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
24000.], dtype=torch.float64)
''''
参考資料