Pytorch:自分で作成したデータセットを定義する方法
5996 ワード
本文は個人の知識学習の記録であり、将来は復習して振り返ることができる.
Pytorchでデータセットを定義するには、主に2つの主要なクラスに関連します. Datasets DataLoader
1 Datasets
1.1 Datasetsとは?
Datasetsは私たちが使っているデータセットのライブラリで、pytorchはCifar 10、MNISTなど多くのデータセットを持っています.
1.2 Datasetsを定義する理由
Pytorchにはツール関数torchがあります.utils.Data.DataLoaderは、この関数を使用して、mini-batchを使用してデータセットをロードする準備をするときにマルチスレッド並列処理を使用することができ、データセットの準備を高速化することができます.Datasetsは、このツール関数を構築するインスタンスパラメータの1つです.
1.3 Datasetsの定義方法
Datasetクラスは、Pytorchの画像データセットで最も重要なクラスであり、Pytorchのすべてのデータセットロードクラスで継承すべき親でもあります.親クラスの2つのプライベートメンバー関数を再ロードする必要があります.そうしないと、エラー・プロンプトがトリガーされます. __len__:戻りデータセットのサイズ __getitem__:データセットインデックスをサポートする関数 の作成
注意:
ポイントはgetitem関数で、getitemはindexを受信し、ピクチャデータとラベルを返します.このindexは通常listのindexを指し、このlistの各要素にはピクチャデータのパスとラベル情報が含まれています.
リストの作成方法
通常の方法は、画像のパスとラベル情報をtxtに格納し、そのtxtから読み出すことである.
1.3.1データ読み出しの基本フローピクチャのパスとタグ情報を格納txt を作成する.は、これらの情報をlistに変換し、各要素はサンプル に対応する. getitem関数によりデータとラベルを読み出し、データとラベル を返す.
1.3.2 Datasetsの全体フレームワーク
1.4栗
この栗は画像分割実験で独自のデータセットを作成します.
2 DataLoader
Datasetクラスは、データセットデータを読み込み、読み込んだデータをインデックスします.しかし、この機能だけでは十分ではありません.実際にデータセットをロードする過程で、私たちのデータ量は往々にして大きくなります.これにはいくつかの機能が必要です.ロット別読み出し可能:batch-size はデータをランダムに読み取ることができ、データをシャッフル操作(shuffling)することができ、データセット内のデータ分布の順序 を乱すことができる.は、データを並列にロードすることができる(マルチコアプロセッサによるデータのロード効率の向上) .
この場合、Dataloaderクラスが必要です.一般的な操作はbatch_です.size(batchあたりのサイズ)、shuffle(shuffle操作を行うかどうか)、num_workers(データをロードするときにいくつかのサブプロセスを使用します).Dataloaderというクラスは、自分でコードを設計する必要はありません.DataLoaderクラスを利用して、私たちが設計したDatasetサブクラスを読み込むだけです.
参照先:https://blog.csdn.net/sinat_42239797/article/details/90641659
Pytorchでデータセットを定義するには、主に2つの主要なクラスに関連します.
1 Datasets
1.1 Datasetsとは?
Datasetsは私たちが使っているデータセットのライブラリで、pytorchはCifar 10、MNISTなど多くのデータセットを持っています.
1.2 Datasetsを定義する理由
Pytorchにはツール関数torchがあります.utils.Data.DataLoaderは、この関数を使用して、mini-batchを使用してデータセットをロードする準備をするときにマルチスレッド並列処理を使用することができ、データセットの準備を高速化することができます.Datasetsは、このツール関数を構築するインスタンスパラメータの1つです.
1.3 Datasetsの定義方法
Datasetクラスは、Pytorchの画像データセットで最も重要なクラスであり、Pytorchのすべてのデータセットロードクラスで継承すべき親でもあります.親クラスの2つのプライベートメンバー関数を再ロードする必要があります.そうしないと、エラー・プロンプトがトリガーされます.
def __getitem__(self,index):
def __len__(self):
注意:
ポイントはgetitem関数で、getitemはindexを受信し、ピクチャデータとラベルを返します.このindexは通常listのindexを指し、このlistの各要素にはピクチャデータのパスとラベル情報が含まれています.
リストの作成方法
通常の方法は、画像のパスとラベル情報をtxtに格納し、そのtxtから読み出すことである.
1.3.1データ読み出しの基本フロー
1.3.2 Datasetsの全体フレームワーク
from torchvision.utils.data import Dataset
class MyDataset(Dataset):# Dataset
def __init__(self):
# TODO
# 1. 。
# , 。
pass
def __getitem__(self, index):
# TODO
#1. ( , numpy.fromfile,PIL.Image.open)。
#2. ( torchvision.Transform)。
#3. ( )。
# , :read one data, data
pass
def __len__(self):
# 0 。
1.4栗
この栗は画像分割実験で独自のデータセットを作成します.
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
#
class TrainDataset(Dataset): # Dataset
def __init__(self, data_path, transform=None): # list, txt
#
self.images = os.listdir(data_path + '/images')
self.labels = os.listdir(data_path + '/masks')
#
assert len(self.images) == len(self.labels), 'Number does not match'
self.transform = transform #
# , ,
# list
self.images_and_labels = [] #
for i in range(len(self.images)): # ,
self.images_and_labels.append(
(data_path + '/images/' + self.images[i], data_path + '/masks/' + self.labels[i])
)
def __getitem__(self, index): # ,
#
image_path, label_path = self.images_and_labels[index]
#
image = cv2.imread(image_path) # ,(H,W,C)
image = cv2.resize(image, (224, 224)) # 224*224
# , ,
label = cv2.imread(label_path, 0) # ,
label = cv2.resize(label, (224, 224)) # 224*224
# , . 0, 1
label = label / 255 # [0,1.0],,
label = label.astype('uint8') # , , 1 , 0。
# one-hot
label = np.eye(2)[label] #
label = np.array(list(map(lambda x: abs(x-1), label))).astype('float32') # 0 1,1 0
label = label.transpose(2, 0, 1) # (H,W,C) => (C,H,W)
if self.transform is not None:
image = self.transform(image)
return image, label #
def __len__(self): # ,
return len(self.images)
#
class TestDataset(Dataset):
def __init__(self, data_path, transform=None):
self.images = os.listdir(data_path + '/images')
self.transform = transform
self.imgs = []
for i in range(len(self.images)):
# self.imgs.append(data_path + '/images/' + self.images[i])
self.imgs.append(os.path.join(data_path, 'images/', self.images[i]))
def __getitem__(self, item):
img_path = self.imgs[item]
img = cv2.imread(img_path)
img = cv2.resize(img, (224, 224))
if self.transform is not None:
img = self.transform(img)
return img
def __len__(self):
return len(self.images)
if __name__ == '__main__':
img = cv2.imread('../data/train/masks/150.jpg', 0)
img = cv2.resize(img, (16, 16))
img2 = img / 255
cv2.imshow('pic1', img2)
cv2.waitKey()
print(img2)
img3 = img2.astype('uint8')
cv2.imshow('pic2', img3)
cv2.waitKey()
print(img3)
# 3
hot1 = np.eye(2)[img3] # ,(0,1) ,(1,0)
print(hot1)
print(hot1.ndim)
print(hot1.shape) # (16,16,2) C=16,H=16,W=16
hot2 = np.array(list(map(lambda x: abs(x - 1), hot1))) # 。(1,0) ,(0,1)
print(hot2)
print(hot2.ndim)
print(hot2.shape) # (16,16,2) C=16,H=16,W=16
hot3 = hot2.transpose(2, 0, 1)
print(hot3) # (C=2,H=16,W=16)
2 DataLoader
Datasetクラスは、データセットデータを読み込み、読み込んだデータをインデックスします.しかし、この機能だけでは十分ではありません.実際にデータセットをロードする過程で、私たちのデータ量は往々にして大きくなります.これにはいくつかの機能が必要です.
この場合、Dataloaderクラスが必要です.一般的な操作はbatch_です.size(batchあたりのサイズ)、shuffle(shuffle操作を行うかどうか)、num_workers(データをロードするときにいくつかのサブプロセスを使用します).Dataloaderというクラスは、自分でコードを設計する必要はありません.DataLoaderクラスを利用して、私たちが設計したDatasetサブクラスを読み込むだけです.
from torchvision.utils.data import DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True ,num_workers=4)
test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)
参照先:https://blog.csdn.net/sinat_42239797/article/details/90641659