【Augmentation Zoo】RetinaNet+VOC+KITTIのデータプリプロセッシング-pytorch版


この間に見たデータ増強手法を統合し,VOCとKITTIデータでの効果を試験した.私の仕事はVOCとKITTIデータの前処理を完了し、RetinaNetのモデルコードはpytorch-retinanetから来ています.
このプロジェクトgithub倉庫は次のとおりです.https://github.com/zzl-pointcloud/MyRetinaNet0722.
 
目次
一、VOCデータの前処理
二、KITTIデータの前処理
三、Resizerクラスとcollater()クラス
1.Resizerクラス
2.Collaterクラス
コード全体の処理ロジックは次のとおりです.
  • はtorchを継承する.DatasetクラスVocDatasetsクラス、KittiDatasetsクラス、書き換え_などの新しいデータセットクラスを定義します.getitem__(image_index)関数は、ピクチャ番号を入力し、sample={'img':img,'annots':annots}を返す機能です.クラス内の他の関数は__にサービスされます.getitem__load_などの関数image(),load_annotations()など.
  • transformをDatasetに転送し、transform.Compose([fun1(), fun2(), ...]).ここでfunはobject継承クラスであり、その中の__を定義するcall__()は、関数として使用できるようにします.各ピクチャに対して関数fun 1(),fun 2(),....ここでfun()はデータ増強法の入口である.
  • sampler(データセットからサンプルを取るポリシー)処理後、データセットクラスはDataLoaderオブジェクトに変換される.samplerで設定したyieldにより、各データを反復して返します.
  • は、このデータ前処理部で完了し、モデルに送られてトレーニングを開始する.訓練のデータの流れは個人的に以下のように理解しています:
  • 各epochは、サンプリング本のポリシーは、ステップ3のsamplerによって決定されるまで、データ全体をモデル内で1回実行する.各epochにおいて、データセットN=batch_size * iter_num、各iter、順方向伝播、逆方向伝播、検証セットテスト、モデル保存など.
    retinanet
    
    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)
    
    
    for epoch_num in range(epochs):
        for iter_num, data in enumerate(dataloader_train):
            #  , loss
            retinanet.train()
            classification_loss, regression_loss = retinanet([data['img'].float, data['annot']])  
            classification_loss = classification_loss.mean()
            regression_loss = regression_loss.mean()
            loss = classification_loss + regression_loss
            
            # , 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        """
        validation part
    
        """
    
    """
    test part
    
    """
    #  
    torch.save(retinanet, "model_final.pt")

    一、VOCデータの前処理

    class VocDataset(Dataset):
        def __init__(self,
                     root_dir,
                     image_set='train',         # train/val/test
                     years=['2007', '2012'],    #  2007+2012
                     transform=None,
                     keep_difficult=False
                     ):
            self.root_dir = root_dir
            self.years = years
            self.image_set = image_set
            self.transform = transform
            self.keep_difficult = keep_difficult
    
            self.categories = VOC_CLASSES
    
            self.name_2_label = dict(
                zip(self.categories, range(len(self.categories)))
            )
            self.label_2_name = {
                v: k
                for k, v in self.name_2_label.items()
            }
            self.ids = list()
            self.find_file_list()
    
        def __len__(self):
            return len(self.ids)
    
        def __getitem__(self, image_index):
    
            img = self.load_image(image_index)
            annots = self.load_annotations(image_index)
            sample = {'img':img, 'annot':annots}
            if self.transform:
                sample = self.transform(sample)
            return sample
    
        def find_file_list(self):
            for year in self.years:
                if not (year == '2012' and self.image_set == 'test'):
                    root_path = os.path.join(self.root_dir, 'VOC' + year)
                    file_path = os.path.join(root_path, 'ImageSets', 'Main', self.image_set + '.txt')
                    for line in open(file_path):
                        self.ids.append((root_path, line.strip()))
    
        def load_image(self, image_index):
            image_root_dir, img_idx = self.ids[image_index]
            image_path = os.path.join(image_root_dir,
                                     'JPEGImages', img_idx + '.jpg')
            img = cv2.imread(image_path)
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
            return img.astype(np.float32)/255.0
    
        def load_annotations(self, image_index):
            image_root_dir, img_idx = self.ids[image_index]
            anna_path = os.path.join(image_root_dir,
                                    'Annotations', img_idx + '.xml')
            annotations = []
            target = ET.parse(anna_path).getroot()
            for obj in target.iter("object"):
                difficult = int(obj.find('difficult').text) == 1
                if not self.keep_difficult and difficult:
                    continue
                bbox = obj.find('bndbox')
    
                pts = ['xmin', 'ymin', 'xmax', 'ymax']
    
                bndbox = []
                for pt in pts:
                    cut_pt = bbox.find(pt).text
                    bndbox.append(float(cut_pt))
                name = obj.find('name').text.lower().strip()
                label = self.name_2_label[name]
                bndbox.append(label)
                annotations += [bndbox]
            annotations = np.array(annotations)
    
            return annotations
    
        def label_to_name(self, voc_label):
            return self.label_2_name[voc_label]
    
        def name_to_label(self, voc_name):
            return self.name_2_label[voc_name]
    
        def image_aspect_ratio(self, image_index):
            image_root_dir, img_idx = self.ids[image_index]
            image_path = os.path.join(image_root_dir,
                                      'JPEGImages', img_idx + '.jpg')
            img = cv2.imread(image_path)
            return float(img.shape[1] / float(img.shape[0]))
    
        def num_classes(self):
            return 20

    二、KITTIデータの前処理


    KITTIに対するデータの前処理コードはVOCと似ているが、KittiDatasetクラスを初期化する前に、KITTIデータセットを手動でトレーニング/検証セットに分割し、VOCに類似するtrainを生成する必要がある.txtとval.txtファイル.そこで私はSplitKittiDatasetクラス(tools.py)を実現しました.
    1.ファイル名リスト、およびlenを取得
    2.range(len)でindexを生成し、乱した後、区分割合でtrain_をとるindexとval_index、listから対応するファイル名を取ります
    3.txtファイルに保存します.
    class KittiDataset(Dataset):
        def __init__(self,
                     root_dir,
                     sets,
                     transform=None,
                     keep_difficult=False
                     ):
            self.root_dir = root_dir
            self.sets = sets
            self.transform = transform
            self.keep_difficult = keep_difficult
    
            self.categories = KITTI_CLASSES
    
            self.name_2_label = dict(
                zip(self.categories, range(len(self.categories)))
            )
            self.label_2_name = {
                v: k
                for k, v in self.name_2_label.items()
            }
            self.ids = list()
            self.find_file_list()
    
        def __len__(self):
            return len(self.ids)
    
        def __getitem__(self, image_index):
            img = self.load_image(image_index)
            annot = self.load_annotations(image_index)
            sample = {'img':img, 'annot':annot}
            if self.transform:
                sample = self.transform(sample)
            return sample
    
        def find_file_list(self):
            file_path = os.path.join(self.root_dir, self.sets + '.txt')
            for line in open(file_path):
                self.ids.append(line.strip())
    
        def load_image(self, image_index):
            img_idx = self.ids[image_index]
            image_path = os.path.join(self.root_dir,
                                     'image_2', img_idx + '.png')
            img = cv2.imread(image_path)
            if len(img.shape) == 2:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
            return img.astype(np.float32)/255.0
    
        def load_annotations(self, image_index):
            img_idx = self.ids[image_index]
            anna_path = os.path.join(self.root_dir,
                                    'label_2', img_idx + '.txt')
            annotations = []
            with open(anna_path) as file:
                lines = file.readlines()
                for line in lines:
                    items = line.split(" ")
                    name = items[0].lower().strip()
                    if name == 'dontcare':
                        continue
                    else:
                        bndbox = [float(items[i+4]) for i in range(4)]
                        label = self.name_2_label[name]
                        bndbox.append(int(label))
                    annotations.append(bndbox)
            annotations = np.array(annotations)
            return annotations
    
        def label_to_name(self, voc_label):
            return self.label_2_name[voc_label]
    
        def name_to_label(self, voc_name):
            return self.name_2_label[voc_name]
    
        def image_aspect_ratio(self, image_index):
            img_idx = self.ids[image_index]
            image_path = os.path.join(self.root_dir,
                                      'image_2', img_idx + '.png')
            img = cv2.imread(image_path)
            return float(img.shape[1] / float(img.shape[0]))
    
        def num_classes(self):
            return 8

    三、Resizerクラスとcollater()クラス


    それぞれ、画像を限定サイズと位置合わせに変更するために使用します.

    1.Resizerクラス

     , 608/1024
    scale =   /  
    if   * scale >  :
        scale =   /  
    resized_image = cv2.resize(image,   * scale,   * scale)
    
     resized_image 32 

    2.Collaterクラス


    画像はデータセットの最長幅と最長高さ(例えば、長辺上限*長辺上限)に塗りつぶされ、画像は左上隅から揃えられ、残りの部分は0に塗りつぶされます.
    annots充填はsample単位で最大目標数に拡張され、残りの充填-1