PyTorchソースコード解読のtorch.utils.data.DataLoader
23403 ワード
PyTorchにおけるデータ読み出しの重要なインタフェースはtorchである.utils.data.DataLoader,このインタフェースはdataloaderに定義されている.pyスクリプトでは,PyTorchを用いてモデルを訓練する場合には基本的にこのインタフェースが用いられるが,このインタフェースは主にカスタムデータ読み出しインタフェースの出力やPyTorchの既存のデータ読み出しインタフェースの入力をbatch sizeに従ってTensorにカプセル化し,その後はVariableに再パッケージするだけでモデルの入力として利用できるため,このインタフェースは少し啓発的な役割を果たすことが重要である.このブログでは、このインタフェースのソースコードについて説明します.主にDataLoaderとDataLoaderIterの2つのクラスが含まれています.dataloader.pyスクリプトのgithubアドレス:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py
DataLoaderクラスのソースコードは以下の通りです.まず
DataLoaderIterクラスのソースコードは次のとおりです.self.index_queue = multiprocessing.SimpleQueue()のmultiprocessingはPythonのマルチプロセス管理パケットであり、threadingはPythonのマルチスレッド管理パケットであり、両者の多くのインタフェースの使い方は似ている.やはり例によってまず
DataLoaderIterクラスの
pin_memory_batch関数は、DataLoaderクラスまたはDataLoaderIterクラスに定義されていません.この関数は主にbatch中のTensorに対してbatchを実行する.pin_memory()操作、ここでの多くの条件文はbatchのタイプを判断するために使用され、batchがリストであり、リストの各値がTensorである場合、elif isinstance(batch,collections.Sequence):この条件が実行され、リスト内の各Tensorを遍歴し、最初の条件文の内容:return batchを実行する.pin_memory()
DataloaderIterクラスの_get_batchメソッド.主にタイムアウト時間が設定されているかどうかによって操作され、指定されたタイムアウト時間を超えた後にキューからデータを読み込まなければエラーを報告し、タイムアウト時間を設定せずに一致してキューからデータを読み込まなければ、ずっと詰まっていてエラーを報告しない.この部分はPyTorchが後に修理したバグである.
DataLoaderIterクラスの_process_next_batchメソッド.まずself.rcvd_idxは、次のbatchデータを更新するindexを加算します.そして呼び出し_put_indices()メソッドは、次のbatchの各データのindexを取得する.
DataLoaderIterクラスの_put_indicesメソッド.この方法は主にselfから実現する.sample_iterで次のbatchデータの各データのindex:indices=next(self.sample_iter,None)を読み出し、ここでのindexは前のidxとは異なり、ここでのindexはbatchの各データのindexであり、idxはbatchのindexであることに注意する.そして、読み出したindexをqueueオブジェクトを呼び出すputメソッドによりキューselfに押す.index_Queue中:self.index_queue.put((self.send_idx, indices))
DataLoaderクラスのソースコードは以下の通りです.まず
__init__
のいくつかの重要な入力を見てみましょう:1、dataset、これはPyTorchの既存のデータ読み出しインタフェース(例えばtorchvision.datasets.ImageFolder)またはカスタムデータインタフェースの出力であり、この出力はtorchである.utils.data.Datasetクラスのオブジェクト、またはtorchから継承する.utils.data.Datasetクラスのカスタムクラスのオブジェクト.2、batch_size、状況に応じて設定すればいいです.3、shuffleは、一般的にトレーニングデータに採用されています.4、collate_fnは、異なる場合の入力datasetを処理するためのパッケージであり、カスタマイズされたデータ読み出し出力が非常に珍しい場合を除き、一般的にデフォルトでよい.5、batch_sampler、注釈から分かるようにbatch_size、shuffleなどのパラメータは反発し合い、一般的にデフォルトを採用します.6、sampler、コードから分かるように、shuffleと反発し、一般的にデフォルトでよい.7、num_workersは、コメントからこのパラメータが0以上でなければならないことがわかります.0は、データのインポートがメインプロセスで行われることを示し、他の0以上の数は、複数のプロセスでデータをインポートすることで、データのインポート速度を速めることができます.8、pin_memory、注釈がはっきり書いてあります:pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. つまり、データコピーの問題です.9、timeoutは、データの読み込みのタイムアウト時間を設定するためのものですが、それを超えてもまだデータが読み込まれていない場合はエラーとなります.__init__
では、RandomSamplerクラスはランダムサンプリングを表し、重複しないため、shuffleの役割を果たす.BatchSamplerクラスは、batch size個のRandomSamplerクラスオブジェクトを1つにカプセル化し、ランダムにbatchを選択する目的を実現します.この2つのサンプリングクラスはsamplerに定義されている.pyスクリプト、アドレス:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py.これらはいずれも初期化の際に行われたものである.コードがtorchから実行されるまでutils.data.DataLoaderクラスが生成したオブジェクトからデータを取得する場合、たとえば、train_data=torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
...
でDataLoaderクラスの__iter__
メソッドが呼び出され、__iter__
メソッドでは1行のコード:return DataLoaderIter(self)が入力され、DataLoaderクラスの属性である.したがって、__iter__
メソッドが呼び出されると、別のクラス:DataLoaderIterに関連し、次に説明する.class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
Arguments:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: 1).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: False).
sampler (Sampler, optional): defines the strategy to draw samples from
the dataset. If specified, ``shuffle`` must be False.
batch_sampler (Sampler, optional): like sampler, but returns a batch of
indices at a time. Mutually exclusive with batch_size, shuffle,
sampler, and drop_last.
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means that the data will be loaded in the main process.
(default: 0)
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
into CUDA pinned memory before returning them.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: False)
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional): If not None, this will be called on each
worker subprocess with the worker id as input, after seeding and before data
loading. (default: None)
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
if timeout < 0:
raise ValueError('timeout option should be non-negative')
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if self.num_workers < 0:
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
DataLoaderIterクラスのソースコードは次のとおりです.self.index_queue = multiprocessing.SimpleQueue()のmultiprocessingはPythonのマルチプロセス管理パケットであり、threadingはPythonのマルチスレッド管理パケットであり、両者の多くのインタフェースの使い方は似ている.やはり例によってまず
__init__
を見て、前の部分はすべていくつかの賦値操作で、比較的に特殊なのはselfです.sample_iter=iter(self.batch_sampler)、得られたself.sample_iterはnext(self.sample_iter)によってbatch size個のデータのindexを取得することができる.self.rcvd_idxは、読み込まれたbatchデータのindexを表し、0に初期化され、この値はデータを反復して読み込むときに使用されます.if self.num_workers文は、マルチプロセスまたは単一プロセスの場合に初期化されます.マルチプロセスとしてデータを読み出すように設定されていない場合は、これらの初期化操作は必要ありません.後で、単一プロセスデータの読み出しについて説明します.if文でmultiprocessing.SimpleQueue()クラスは、単純なキューオブジェクトを作成します.multiprocessing.Processクラスは構築プロセスのクラスである、ここでは設定プロセス数に基づいて起動しselfに値を与える.workers.次のforループはstartメソッドを呼び出すことによってselfを順次起動する.workersのプロセス.次にselfについて.pin_memoryの判断文で、この判断文の内部には主にマルチスレッド操作が実現されている.self.pin_memoryの意味は先に紹介しましたが、Trueの場合はCUDAにデータをコピーします.self.data_queue = queue.Queue()はPythonのqueueモジュールの初期化によって1つの先進先出のキューを得る(queueモジュールは先進後出のキューを初期化することもでき、queue.LifoQueue()で初期化する必要がある)、queueモジュールは主にマルチスレッド読み出しデータに応用される.threading.Threadのargsパラメータのうち、最初のパラメータin_Dataは1つのプロセスのデータであり、1つのプロセスの異なるスレッドのデータもキューによって維持され、ここではPythonのqueueモジュールを採用して初期化して1つのキューを得る:queue.Queue().初期化が完了すると、__next__
メソッドが呼び出されます.次に説明します.総じて、マルチプロセスでデータを読み出すように設定すると、キューで読み、マルチプロセスでデータを読み込まない場合は、通常の方法で読みます.class DataLoaderIter(object):
"Iterates once over the DataLoader's dataset, as specified by the sampler"
def __init__(self, loader):
self.dataset = loader.dataset
self.collate_fn = loader.collate_fn
self.batch_sampler = loader.batch_sampler
self.num_workers = loader.num_workers
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
self.timeout = loader.timeout
self.done_event = threading.Event()
self.sample_iter = iter(self.batch_sampler)
if self.num_workers > 0:
self.worker_init_fn = loader.worker_init_fn
self.index_queue = multiprocessing.SimpleQueue()
self.worker_result_queue = multiprocessing.SimpleQueue()
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
self.send_idx = 0
self.rcvd_idx = 0
self.reorder_dict = {}
base_seed = torch.LongTensor(1).random_()[0]
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
base_seed + i, self.worker_init_fn, i))
for i in range(self.num_workers)]
if self.pin_memory or self.timeout > 0:
self.data_queue = queue.Queue()
self.worker_manager_thread = threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
torch.cuda.current_device()))
self.worker_manager_thread.daemon = True
self.worker_manager_thread.start()
else:
self.data_queue = self.worker_result_queue
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True
# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()
DataLoaderIterクラスの
__next__
メソッドは、3つのif文と1つのwhile文を含む.最初のif文はselfを処理するために使用される.num_workersが0に等しい場合、すなわちマルチプロセスを用いてデータ読み出しを行わない場合、このif文ではまずindices=next(self.sample_iter)によって長さbatch sizeのリスト:indicesが取得され、このリストの各値はbatch中の各データのindexを表し、next操作を実行するたびにbatch sizeのindicesリストが読み出される.そしてself.collate_fn関数はbatch size個のtuple(各tuple長は2であり、そのうち第1の値はデータ、Tensorタイプ、第2の値はラベル、intタイプ)を1つのlistにカプセル化し、このlist長は2であり、2つの値はいずれもTensorであり、1つはbatch size個のデータからなるFloatTensorであり、もう1つはbatch size個のラベルからなるLongTensorである.簡単に言えばcollate_fn関数はbatch size個の分散したTensorを一つのTensorにカプセル化することである.batch = pin_memory_batch(batch)中pin_memory_batch関数の役割は,入力batchの各TensorをCUDAにコピーすることであり,この関数は後述する.2番目のif文は、現在読み出したいbatchのindex(self.rcvd_idx)が以前に読み出されたか否か(読み出されたindexとbatchデータがself.reorder_dict辞書に保存されているので、self.reorder_dictワードの更新が最後のwhile文であるため、最後のwhile文と併せて見ることができる)であり、前に読み出された場合、このindexに基づいてreorderからdict辞書から対応するデータがポップアップされます.最後にbatchデータを返すときはreturn self.process_next_batch(batch)です.この方法は後で詳しく説明します.主に次のbatchのデータindex情報を取得します.3番目のif文self.batches_outstandingの値は、前の初期にself._を呼び出します.put_indices()メソッドが変更されたので、プロセス数selfを仮定します.num_workersが3に設定されているので、ここではself.batches_outstandingは3*2=6ですが、具体的にはself.put_indices()メソッド.最後のwhileサイクルは本当にキューからデータを読み出すための操作であり、最も主要なのはidx、batch=self._である.get_batch()を呼び出して_get_batch()メソッドで読み込みますが、簡単に言えばキューを呼び出したgetメソッドで次のbatchのデータが得られます.得られたbatchは一般的に長さ2のリストで、リストの2つの値はいずれもTensorで、それぞれデータ(batchの)とラベルを表します._get_batch()メソッドはbatchデータを返す以外に、もう一つの出力を得る:idx、この出力はbatchのindexを表し、このif idx!=self.rcvd_idx条件文は、batchのindexが現在のindex:selg,rcvd_に等しくない場合を示します.idxは、読み出したデータを辞書selfに保存する.reorder_dict中:self.reorder_dict[idx]=batchは、読み出したデータのindexがselfに等しくなるまでデータの読み出しを継続する.rcvd_idx. def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True:
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch()
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
return self._process_next_batch(batch)
pin_memory_batch関数は、DataLoaderクラスまたはDataLoaderIterクラスに定義されていません.この関数は主にbatch中のTensorに対してbatchを実行する.pin_memory()操作、ここでの多くの条件文はbatchのタイプを判断するために使用され、batchがリストであり、リストの各値がTensorである場合、elif isinstance(batch,collections.Sequence):この条件が実行され、リスト内の各Tensorを遍歴し、最初の条件文の内容:return batchを実行する.pin_memory()
def pin_memory_batch(batch):
if torch.is_tensor(batch):
return batch.pin_memory()
elif isinstance(batch, string_classes):
return batch
elif isinstance(batch, collections.Mapping):
return {k: pin_memory_batch(sample) for k, sample in batch.items()}
elif isinstance(batch, collections.Sequence):
return [pin_memory_batch(sample) for sample in batch]
else:
return batch
DataloaderIterクラスの_get_batchメソッド.主にタイムアウト時間が設定されているかどうかによって操作され、指定されたタイムアウト時間を超えた後にキューからデータを読み込まなければエラーを報告し、タイムアウト時間を設定せずに一致してキューからデータを読み込まなければ、ずっと詰まっていてエラーを報告しない.この部分はPyTorchが後に修理したバグである.
def _get_batch(self):
if self.timeout > 0:
try:
return self.data_queue.get(True, self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
else:
return self.data_queue.get()
DataLoaderIterクラスの_process_next_batchメソッド.まずself.rcvd_idxは、次のbatchデータを更新するindexを加算します.そして呼び出し_put_indices()メソッドは、次のbatchの各データのindexを取得する.
def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
DataLoaderIterクラスの_put_indicesメソッド.この方法は主にselfから実現する.sample_iterで次のbatchデータの各データのindex:indices=next(self.sample_iter,None)を読み出し、ここでのindexは前のidxとは異なり、ここでのindexはbatchの各データのindexであり、idxはbatchのindexであることに注意する.そして、読み出したindexをqueueオブジェクトを呼び出すputメソッドによりキューselfに押す.index_Queue中:self.index_queue.put((self.send_idx, indices))
def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
indices = next(self.sample_iter, None)
if indices is None:
return
self.index_queue.put((self.send_idx, indices))
self.batches_outstanding += 1
self.send_idx += 1