自分で車輪を作る:dataloaderを深く勉強して自分で実現する
自分で車輪を作る:dataloaderを深く勉強して自分で実現する
**概要:**コンピュータのパフォーマンスが制限されているため、すべての深さ学習フレームワークは一括ランダム勾配降下を採用しているため、計算のたびにbatch_を読み込む必要があります.sizeのデータ.ここでは,深さ学習フレームワークによる一括読み出しを実現する原理を自己実現で紹介し,具体的な詳細や論理にかかわらず,概略フローと原理だけを重視する.
全体的なプロセス: yieldを用いて生成器関数を書くバッチピクチャ/寸法情報の読み出し を実現する. multiprocessing/threading加速ファイル読み出し を採用する.時間比較 深さ学習の概略フロー
Dataloaderでは、通常、複数のプロセス(num_workers)を使用してファイルI/Oの速度を速め、ネットワークの逆伝播を回避し、データがありません.
1.yieldでジェネレータ関数を書く
2.muitlprocessingを使用してファイルの読み込み速度を速める
100個のbatchを読み取ることで、時間が80倍に向上したことがわかります.また,一般的な深さ学習フレームワークでは,上の機能をいくつかのマルチプロセスで処理する.eg:
ネット上の資料によるとthreadingの効率はmuitlprocessingほど高くなく、ここではテストしません.
reference
[1]python[2]argman/EAST
**概要:**コンピュータのパフォーマンスが制限されているため、すべての深さ学習フレームワークは一括ランダム勾配降下を採用しているため、計算のたびにbatch_を読み込む必要があります.sizeのデータ.ここでは,深さ学習フレームワークによる一括読み出しを実現する原理を自己実現で紹介し,具体的な詳細や論理にかかわらず,概略フローと原理だけを重視する.
全体的なプロセス:
for i in range(epoch):
data, lable = dataloader.next(batch_size=16) # batch_size
output = model(data) #
loss = crition(output, label) #
loss.backward() #
Dataloaderでは、通常、複数のプロセス(num_workers)を使用してファイルI/Oの速度を速め、ネットワークの逆伝播を回避し、データがありません.
1.yieldでジェネレータ関数を書く
# coding:utf-8
# ,
import os
import glob
import numpy as np
import cv2
def get_images(path):
files = []
for ext in ['jpg', 'png', 'jpeg', 'JPG']:
files.extend(glob.glob(
os.path.join(path, '*.{}'.format(ext))))
return files
def dataset(batch_size=2, path='/media/chenjun/data/1_deeplearning/7_ammeter_data/test'):
"""
batch_size:
path:
"""
# 1.
image_list = get_images(path)
index = np.arange(0, len(image_list))
while True:
np.random.shuffle(index)
images = []
image_names = []
for i in index:
try:
im_name = image_list[i]
im = cv2.imread(im_name) #
#
# text_polys = fun1()
images.append(im[:,:, ::-1].astype(np.float32)) # cv2 BGR, RGB
image_names.append(im_name)
if len(images) == batch_size:
yield images, image_names # ,
images = []
image_names = []
except Exception as e:
import traceback
traceback.print_exc()
continue # , for ,
2.muitlprocessingを使用してファイルの読み込み速度を速める
<!-- , 100 batch -->
import time
mydataset = dataset()
start = time.time()
for _ in range(100):
im, im_name = next(mydataset)
# print(im_name)
print('use time:{}'.format(time.time() - start))
>>> use time:0.16786599159240723
<!-- muitlprocessing , 100 batch -->
import multiprocessing
def data_generator(data, q):
for _ in range(100): #
generator_output = next(data)
q.put(generator_output)
q = multiprocessing.Queue()
start2 = time.time()
thread = multiprocessing.Process(target=data_generator, args=(dataset(), q))
thread.start() #
print('mulprocess time is:{}'.format(time.time() - start2))
>>> mulprocess time is:0.002292633056640625
100個のbatchを読み取ることで、時間が80倍に向上したことがわかります.また,一般的な深さ学習フレームワークでは,上の機能をいくつかのマルチプロセスで処理する.eg:
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.random_seed)
thread = multiprocessing.Process(target=data_generator_task)
ネット上の資料によるとthreadingの効率はmuitlprocessingほど高くなく、ここではテストしません.
reference
[1]python[2]argman/EAST