自分で車輪を作る:dataloaderを深く勉強して自分で実現する

14956 ワード

自分で車輪を作る:dataloaderを深く勉強して自分で実現する
**概要:**コンピュータのパフォーマンスが制限されているため、すべての深さ学習フレームワークは一括ランダム勾配降下を採用しているため、計算のたびにbatch_を読み込む必要があります.sizeのデータ.ここでは,深さ学習フレームワークによる一括読み出しを実現する原理を自己実現で紹介し,具体的な詳細や論理にかかわらず,概略フローと原理だけを重視する.
全体的なプロセス:
  • yieldを用いて生成器関数を書くバッチピクチャ/寸法情報の読み出し
  • を実現する.
  • multiprocessing/threading加速ファイル読み出し
  • を採用する.
  • 時間比較
  • 深さ学習の概略フロー
    for i in range(epoch):
        data, lable = dataloader.next(batch_size=16)         #   batch_size   
        output = model(data)            #     
        loss = crition(output, label)   #      
        loss.backward()                 #     
    

    Dataloaderでは、通常、複数のプロセス(num_workers)を使用してファイルI/Oの速度を速め、ネットワークの逆伝播を回避し、データがありません.
    1.yieldでジェネレータ関数を書く
    # coding:utf-8
    #      ,             
    import os
    import glob
    import numpy as np 
    import cv2  
    
    
    def get_images(path):
        files = []
        for ext in ['jpg', 'png', 'jpeg', 'JPG']:
            files.extend(glob.glob(
                os.path.join(path, '*.{}'.format(ext))))
        return files
    
    
    def dataset(batch_size=2, path='/media/chenjun/data/1_deeplearning/7_ammeter_data/test'):
        """
                       
            batch_size:    
            path:    
        """
        # 1.         
        image_list = get_images(path)
        index = np.arange(0, len(image_list))
        while True:
            np.random.shuffle(index)
            images = []
            image_names = []
            for i in index:
                try:
                    im_name = image_list[i]
                    im = cv2.imread(im_name)    #     
                    #            
                    # text_polys = fun1()
                    images.append(im[:,:, ::-1].astype(np.float32))     # cv2        BGR,   RGB  
                    image_names.append(im_name)
    
                    if len(images) == batch_size:
                        yield images, image_names        #        ,         
                        images = []
                        image_names = []
                
                except Exception as e:
                    import traceback
                    traceback.print_exc()
                    continue                #           ,  for  ,               
    

    2.muitlprocessingを使用してファイルの読み込み速度を速める
    <!--100 batch -->
    import time
    mydataset = dataset()
    start = time.time()
    for _ in range(100):
        im, im_name = next(mydataset)
    #     print(im_name)
    print('use time:{}'.format(time.time() - start))
    >>>  use time:0.16786599159240723
    
    
    <!--   muitlprocessing        ,  100 batch -->
    import multiprocessing
    def data_generator(data, q):
        for _ in range(100):                #      
            generator_output = next(data)
            q.put(generator_output)
    
    q = multiprocessing.Queue()
    start2 = time.time()
    thread = multiprocessing.Process(target=data_generator, args=(dataset(), q))
    thread.start()              #          
    print('mulprocess time is:{}'.format(time.time() - start2))
    >>>  mulprocess time is:0.002292633056640625
    

    100個のbatchを読み取ることで、時間が80倍に向上したことがわかります.また,一般的な深さ学習フレームワークでは,上の機能をいくつかのマルチプロセスで処理する.eg:
    for _ in range(workers):
                    if self._use_multiprocessing:
                        # Reset random seed else all children processes
                        # share the same seed
                        np.random.seed(self.random_seed)
                        thread = multiprocessing.Process(target=data_generator_task)
                        
    

    ネット上の資料によるとthreadingの効率はmuitlprocessingほど高くなく、ここではテストしません.
    reference
    [1]python[2]argman/EAST