Pytorch:自分で作成したデータセットを定義する方法


本文は個人の知識学習の記録であり、将来は復習して振り返ることができる.
Pytorchでデータセットを定義するには、主に2つの主要なクラスに関連します.
  • Datasets
  • DataLoader

  • 1 Datasets
    1.1 Datasetsとは?
    Datasetsは私たちが使っているデータセットのライブラリで、pytorchはCifar 10、MNISTなど多くのデータセットを持っています.
    1.2 Datasetsを定義する理由
    Pytorchにはツール関数torchがあります.utils.Data.DataLoaderは、この関数を使用して、mini-batchを使用してデータセットをロードする準備をするときにマルチスレッド並列処理を使用することができ、データセットの準備を高速化することができます.Datasetsは、このツール関数を構築するインスタンスパラメータの1つです.
    1.3 Datasetsの定義方法
    Datasetクラスは、Pytorchの画像データセットで最も重要なクラスであり、Pytorchのすべてのデータセットロードクラスで継承すべき親でもあります.親クラスの2つのプライベートメンバー関数を再ロードする必要があります.そうしないと、エラー・プロンプトがトリガーされます.
    def __getitem__(self,index):
    
    def __len__(self):
  • __len__:戻りデータセットのサイズ
  • __getitem__:データセットインデックスをサポートする関数
  • の作成
    注意:
    ポイントはgetitem関数で、getitemはindexを受信し、ピクチャデータとラベルを返します.このindexは通常listのindexを指し、このlistの各要素にはピクチャデータのパスとラベル情報が含まれています.
    リストの作成方法
    通常の方法は、画像のパスとラベル情報をtxtに格納し、そのtxtから読み出すことである.
    1.3.1データ読み出しの基本フロー
  • ピクチャのパスとタグ情報を格納txt
  • を作成する.
  • は、これらの情報をlistに変換し、各要素はサンプル
  • に対応する.
  • getitem関数によりデータとラベルを読み出し、データとラベル
  • を返す.
    1.3.2 Datasetsの全体フレームワーク
    from torchvision.utils.data import Dataset
    
    class MyDataset(Dataset):#    Dataset
        def __init__(self):
            # TODO
            # 1.              。
            #         ,                     。
            pass
        def __getitem__(self, index):
            # TODO
    
            #1.          (  ,  numpy.fromfile,PIL.Image.open)。
            #2.     (  torchvision.Transform)。
            #3.     (       )。
            #        ,   :read one data,   data
            pass
        def __len__(self):
            #    0          。
    

    1.4栗
    この栗は画像分割実験で独自のデータセットを作成します.
    import os
    import cv2
    import numpy as np
    from torch.utils.data import Dataset
    
    
    #    
    class TrainDataset(Dataset):  #   Dataset
        def __init__(self, data_path, transform=None):  #     list,                txt 
            #     
            self.images = os.listdir(data_path + '/images')
            self.labels = os.listdir(data_path + '/masks')
    
            #              
            assert len(self.images) == len(self.labels), 'Number does not match'
    
            self.transform = transform  #        
    
            #        ,           ,        
            #   list  
            self.images_and_labels = []  #        
            for i in range(len(self.images)):  #         ,      
                self.images_and_labels.append(
                    (data_path + '/images/' + self.images[i], data_path + '/masks/' + self.labels[i])
                )
    
        def __getitem__(self, index):  #        ,        
            #     
            image_path, label_path = self.images_and_labels[index]
            #     
            image = cv2.imread(image_path)  #     ,(H,W,C)
            image = cv2.resize(image, (224, 224))  #         224*224
    
            #        ,         ,     
            label = cv2.imread(label_path, 0)  #     ,     
            label = cv2.resize(label, (224, 224))  #         224*224
            #        ,       .   0,    1
            label = label / 255  #        [0,1.0],,              
            label = label.astype('uint8')  #         ,       ,      1 ,  0。
    
            # one-hot  
            label = np.eye(2)[label]  #              
            label = np.array(list(map(lambda x: abs(x-1), label))).astype('float32')  # 0  1,1  0
            label = label.transpose(2, 0, 1)  # (H,W,C) => (C,H,W)
    
            if self.transform is not None:
                image = self.transform(image)
            return image, label  #     
    
        def __len__(self):  #    ,        
            return len(self.images)
    
    
    #    
    class TestDataset(Dataset):
        def __init__(self, data_path, transform=None):
            self.images = os.listdir(data_path + '/images')
            self.transform = transform
            self.imgs = []
            for i in range(len(self.images)):
                # self.imgs.append(data_path + '/images/' + self.images[i])
                self.imgs.append(os.path.join(data_path, 'images/', self.images[i]))
    
        def __getitem__(self, item):
            img_path = self.imgs[item]
            img = cv2.imread(img_path)
            img = cv2.resize(img, (224, 224))
    
            if self.transform is not None:
                img = self.transform(img)
            return img
    
        def __len__(self):
            return len(self.images)
    
    
    if __name__ == '__main__':
        img = cv2.imread('../data/train/masks/150.jpg', 0)
        img = cv2.resize(img, (16, 16))
        img2 = img / 255
        cv2.imshow('pic1', img2)
        cv2.waitKey()
        print(img2)
    
        img3 = img2.astype('uint8')
        cv2.imshow('pic2', img3)
        cv2.waitKey()
        print(img3)
    
        #           3 
        hot1 = np.eye(2)[img3]  #                ,(0,1)    ,(1,0)    
        print(hot1)
        print(hot1.ndim)
        print(hot1.shape)  # (16,16,2) C=16,H=16,W=16
    
        hot2 = np.array(list(map(lambda x: abs(x - 1), hot1))) #       。(1,0)    ,(0,1)    
        print(hot2)
        print(hot2.ndim)
        print(hot2.shape)  # (16,16,2) C=16,H=16,W=16
    
        hot3 = hot2.transpose(2, 0, 1)
        print(hot3)  # (C=2,H=16,W=16)
    

    2 DataLoader
    Datasetクラスは、データセットデータを読み込み、読み込んだデータをインデックスします.しかし、この機能だけでは十分ではありません.実際にデータセットをロードする過程で、私たちのデータ量は往々にして大きくなります.これにはいくつかの機能が必要です.
  • ロット別読み出し可能:batch-size
  • はデータをランダムに読み取ることができ、データをシャッフル操作(shuffling)することができ、データセット内のデータ分布の順序
  • を乱すことができる.
  • は、データを並列にロードすることができる(マルチコアプロセッサによるデータのロード効率の向上)
  • .
    この場合、Dataloaderクラスが必要です.一般的な操作はbatch_です.size(batchあたりのサイズ)、shuffle(shuffle操作を行うかどうか)、num_workers(データをロードするときにいくつかのサブプロセスを使用します).Dataloaderというクラスは、自分でコードを設計する必要はありません.DataLoaderクラスを利用して、私たちが設計したDatasetサブクラスを読み込むだけです.
    from torchvision.utils.data import DataLoader
    
    train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True ,num_workers=4)
    test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)
    

    参照先:https://blog.csdn.net/sinat_42239797/article/details/90641659