A Deep Dive into PyTorch's DataLoader


This article is based on: https://www.cnblogs.com/ranjiewen/p/10128046.html PyTorch Study Notes (6): DataLoader Source Code Analysis
  • A DataLoader is fundamentally an iterable object: you access it through iter(), not by calling next() on it directly.
  • What iter(dataloader) returns is an iterator, which you can then step through with next().
  • You can also traverse it with for inputs, labels in dataloaders.
  • Typically, we implement a Dataset object and pass it to a DataLoader, which internally uses yield to return batches of data.
  • ① DataLoader is essentially an iterable (just like Python's built-in list type); it uses multiple processes to speed up batch loading and relies on yield to keep memory usage bounded. ② It exploits queue semantics: when the queue is empty, queue.get() blocks until some other process/thread (a worker) calls queue.put(), after which the get succeeds; when the queue is full, queue.put() blocks (see the sketch below). ③ DataLoader is an efficient, concise, and intuitive structure for feeding input data to a network, easy to use and to extend.
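    A toy sketch of the blocking-queue behavior described in ② (plain Python, not PyTorch code; the names are made up for illustration):

    import threading, queue, time

    q = queue.Queue(maxsize=1)

    def producer():
        time.sleep(0.1)
        q.put('batch-0')       # unblocks the consumer's q.get(); a second put() would block until the slot frees

    threading.Thread(target=producer).start()
    print(q.get())             # blocks until the producer puts an item, then prints 'batch-0'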
    Input data pipeline: in PyTorch, loading data into a model follows these steps:
    ① create a Dataset object; ② create a DataLoader object; ③ loop over the DataLoader object, feeding each batch to the model for training:
    dataset = MyDataset()
    dataloader = DataLoader(dataset)
    num_epoches = 100
    for epoch in range(num_epoches):
        for img, label in dataloader:
            ...
    Therefore, as the key step that feeds data directly into the model, the DataLoader is extremely important.
    First, a brief introduction: DataLoader is the central interface for reading data in PyTorch, defined in dataloader.py. If you train a model with PyTorch, you will almost certainly use it. It wraps a custom Dataset, according to the batch size and the shuffle option, into batch-sized Tensors for the subsequent training loop.
    Data loading combines a dataset and a sampler, and iterates over the data with single- or multi-process iterators in Python. On the difference between an iterator and an iterable: an iterator implements both the __iter__ and __next__ methods, whereas an iterable implements only __iter__, as the check below illustrates.
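    A quick check of that distinction on a real loader (a minimal sketch; MyDataset stands in for any Dataset implementation):

    loader = DataLoader(MyDataset())
    print(hasattr(loader, '__iter__'))   # True  -> DataLoader is an iterable
    print(hasattr(loader, '__next__'))   # False -> but not itself an iterator
    it = iter(loader)                    # returns a _DataLoaderIter
    print(hasattr(it, '__next__'))       # True  -> the iterator supports next()
    batch = next(it)                     # fetches one batch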
    1. DataLoader
    First, let's look at the parameters of DataLoader.
  • First, at initialization time the DataLoader obtains a list of sample indices over the dataset, via the samplers set up in __init__:
    class DataLoader(object):
        r"""
        Data loader. Combines a dataset and a sampler, and provides
        single- or multi-process iterators over the dataset.
    
        Arguments:
            dataset (Dataset): dataset from which to load the data.
            batch_size (int, optional): how many samples per batch to load
                (default: 1).
            shuffle (bool, optional): set to ``True`` to have the data reshuffled
                at every epoch (default: False).
            sampler (Sampler, optional): defines the strategy to draw samples from
                the dataset. If specified, ``shuffle`` must be False.
            batch_sampler (Sampler, optional): like sampler, but returns a batch of
                indices at a time. Mutually exclusive with batch_size, shuffle,
                sampler, and drop_last.
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means that the data will be loaded in the main process.
                (default: 0)
            collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            pin_memory (bool, optional): If ``True``, the data loader will copy tensors
                into CUDA pinned memory before returning them.
            drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If ``False`` and
                the size of dataset is not divisible by the batch size, then the last batch
                will be smaller. (default: False)
            timeout (numeric, optional): if positive, the timeout value for collecting a batch
                from workers. Should always be non-negative. (default: 0)
            worker_init_fn (callable, optional): If not None, this will be called on each
                worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
                input, after seeding and before data loading. (default: None)
    
        .. note:: By default, each worker will have its PyTorch seed set to
                  ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
                  each worker to return identical random numbers. (See
                  :ref:`dataloader-workers-random-seed` section in FAQ.) You may
                  use ``torch.initial_seed()`` to access the PyTorch seed for each
                  worker in :attr:`worker_init_fn`, and use it to set other seeds
                  before data loading.
    
        .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                     unpicklable object, e.g., a lambda function.
        """
    
        __initialized = False
    
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                     num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                     timeout=0, worker_init_fn=None):
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
    
            if timeout < 0:
                raise ValueError('timeout option should be non-negative')
    
            if batch_sampler is not None:
                if batch_size > 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and '
                                     'drop_last')
                self.batch_size = None
                self.drop_last = None
    
            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')
    
            if self.num_workers < 0:
                raise ValueError('num_workers option cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')
    
            if batch_sampler is None:
                if sampler is None:
                    if shuffle:
                    sampler = RandomSampler(dataset)  # yields a randomly permuted index list
                    else:
                        sampler = SequentialSampler(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
            self.sampler = sampler
            self.batch_sampler = batch_sampler
            self.__initialized = True
    
        def __setattr__(self, attr, val):
            if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))
    
            super(DataLoader, self).__setattr__(attr, val)
    
        def __iter__(self):
            return _DataLoaderIter(self)
    
        def __len__(self):
            return len(self.batch_sampler)
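    For instance (a minimal sketch), the sampler wiring that __init__ chooses can be inspected directly:

    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    print(type(loader.sampler).__name__)        # RandomSampler (because shuffle=True)
    print(type(loader.batch_sampler).__name__)  # BatchSampler (wraps the sampler above)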
    
    Note: RandomSampler and BatchSampler are what produce the index lists used to draw each batch of data; the yield-based batching mechanism is already in place here!
    class RandomSampler(Sampler):
        r"""Samples elements randomly, without replacement.
    
        Arguments:
            data_source (Dataset): dataset to sample from
        """
    
        def __init__(self, data_source):
            self.data_source = data_source
    
        def __iter__(self):
            return iter(torch.randperm(len(self.data_source)).tolist())
    
        def __len__(self):
            return len(self.data_source)
    
    class BatchSampler(Sampler):
        r"""Wraps another sampler to yield a mini-batch of indices.
    
        Args:
            sampler (Sampler): Base sampler.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If ``True``, the sampler will drop the last batch if
                its size would be less than ``batch_size``
    
        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """
    
        def __init__(self, sampler, batch_size, drop_last):
            if not isinstance(sampler, Sampler):
                raise ValueError("sampler should be an instance of "
                                 "torch.utils.data.Sampler, but got sampler={}"
                                 .format(sampler))
            if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                    batch_size <= 0:
                raise ValueError("batch_size should be a positive integeral value, "
                                 "but got batch_size={}".format(batch_size))
            if not isinstance(drop_last, bool):
                raise ValueError("drop_last should be a boolean value, but got "
                                 "drop_last={}".format(drop_last))
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last
    
        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch
    
        def __len__(self):
            if self.drop_last:
                return len(self.sampler) // self.batch_size
            else:
                return (len(self.sampler) + self.batch_size - 1) // self.batch_size
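    To see how the two compose (mirroring what DataLoader.__init__ does when shuffle=True), a quick sketch:

    from torch.utils.data.sampler import RandomSampler, BatchSampler

    batch_sampler = BatchSampler(RandomSampler(range(8)), batch_size=3, drop_last=False)
    for batch_indices in batch_sampler:
        print(batch_indices)   # e.g. [5, 0, 7], then [2, 6, 1], then [4, 3]: shuffled, last batch smaller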
    
  • _DataLoaderIter takes the DataLoader object as its input. With num_workers == 0 the logic is easy to follow; with num_workers != 0 a multi-process worker mechanism is introduced to speed up data loading.
  • Without workers: batch = self.collate_fn([self.dataset[i] for i in indices]) turns the indices into actual data and returns it; self.dataset[i] invokes the dataset object's __getitem__() method (see the sketch after this list).
  • With workers: an index queue (index_queues[i]) is created for each worker process, and all workers share a single result queue (worker_result_queue). Each worker reads data inside the _worker_loop function.
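    Paraphrasing the num_workers == 0 branch of __next__ below as standalone code (a sketch; loader is any DataLoader built as above):

    sample_iter = iter(loader.batch_sampler)
    indices = next(sample_iter)                     # one batch of indices, e.g. [0, 1, 2, 3]
    samples = [loader.dataset[i] for i in indices]  # each access calls Dataset.__getitem__
    batch = loader.collate_fn(samples)              # default_collate stacks samples into Tensors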
    class _DataLoaderIter(object):
        r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
    
        def __init__(self, loader):
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout
            self.done_event = threading.Event()
    
            self.sample_iter = iter(self.batch_sampler)
    
            base_seed = torch.LongTensor(1).random_().item()
    
            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
                self.worker_queue_idx = 0
                self.worker_result_queue = multiprocessing.SimpleQueue()
                self.batches_outstanding = 0
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0
                self.rcvd_idx = 0
                self.reorder_dict = {}
    
                self.workers = [
                    multiprocessing.Process(
                        target=_worker_loop,
                        args=(self.dataset, self.index_queues[i],
                              self.worker_result_queue, self.collate_fn, base_seed + i,
                              self.worker_init_fn, i))
                    for i in range(self.num_workers)]
    
                if self.pin_memory or self.timeout > 0:
                    self.data_queue = queue.Queue()
                    if self.pin_memory:
                        maybe_device_id = torch.cuda.current_device()
                    else:
                        # do not initialize cuda context if not necessary
                        maybe_device_id = None
                    self.worker_manager_thread = threading.Thread(
                        target=_worker_manager_loop,
                        args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                              maybe_device_id))
                    self.worker_manager_thread.daemon = True
                    self.worker_manager_thread.start()
                else:
                    self.data_queue = self.worker_result_queue
    
                for w in self.workers:
                    w.daemon = True  # ensure that the worker exits on process exit
                    w.start()
    
                _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _set_SIGCHLD_handler()
                self.worker_pids_set = True
    
                # prime the prefetch loop
                for _ in range(2 * self.num_workers):
                    self._put_indices()
    
        def __len__(self):
            return len(self.batch_sampler)
    
        def _get_batch(self):
            if self.timeout > 0:
                try:
                    return self.data_queue.get(timeout=self.timeout)
                except queue.Empty:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            else:
                return self.data_queue.get()
    
        def __next__(self):
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = pin_memory_batch(batch)
                return batch
    
            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)
    
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration
    
            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)
    
        next = __next__  # Python 2 compatibility
    
        def __iter__(self):
            return self
    
        def _put_indices(self):
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
            self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
            self.batches_outstanding += 1
            self.send_idx += 1
    
        def _process_next_batch(self, batch):
            self.rcvd_idx += 1
            self._put_indices()
            if isinstance(batch, ExceptionWrapper):
                raise batch.exc_type(batch.exc_msg)
            return batch
    
    def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
        global _use_shared_memory
        _use_shared_memory = True
    
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()
    
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)
    
        if init_fn is not None:
            init_fn(worker_id)
    
        watchdog = ManagerWatchdog()
    
        while True:
            try:
                r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                if watchdog.is_alive():
                    continue
                else:
                    break
            if r is None:
                break
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    
  • Queue operations are used throughout: the queues buffer (prefetch) data and thereby improve loading speed, as the toy sketch below demonstrates.
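    A toy reimplementation of that worker pattern (hypothetical and heavily simplified; real workers also reorder results by index, as __next__ shows above):

    import multiprocessing

    def _toy_worker(index_queue, result_queue):
        while True:
            r = index_queue.get()              # blocks until the main process sends indices
            if r is None:                      # None is the shutdown signal
                break
            idx, batch_indices = r
            result_queue.put((idx, [i * 10 for i in batch_indices]))  # fake "loading"

    if __name__ == '__main__':
        index_q = multiprocessing.Queue()
        result_q = multiprocessing.Queue()
        w = multiprocessing.Process(target=_toy_worker, args=(index_q, result_q))
        w.daemon = True
        w.start()
        index_q.put((0, [0, 1, 2]))            # analogous to _put_indices()
        print(result_q.get())                  # (0, [0, 10, 20]), analogous to _get_batch()
        index_q.put(None)                      # ask the worker to exit
        w.join()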