pytorchは自分でトレーニングセットとテストセットを作ることを実現します


pytorchは画像認識に使用できますが、私たちが今ほとんど使っているのはMINISTとcifar 10ピクチャで、自分のトレーニングとテスト画像パスを使うには、トレーニングセットとテストセットを読み取るコードを作る必要があります.本稿ではpytorchがトレーニングセットの読み取りとテストセットの汎用コードを実現することについて述べる.
まず、画像の経路を読み取るフレームワーク:torch.utils.data.Datasetはpytorchがデータセットを表す抽象クラスであり、このクラスで自分のデータセットを処理する際にDatasetを継承しなければならない.次に、len(dataset)がデータセットのサイズgetitemを返すように、len(dataset)がi番目のデータサンプルを返すことができるように、次の関数を書き換える.
次に、img_を調整する必要がある具体的なコードクリップを示します.id=int(img_path[-12:-9])、この段落はファイル名を読むので、自分のファイル名に基づいてimg_を設定する必要がありますpathの数値.このファイルの名前はPathです.py
import glob
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

class path(Dataset):

    def __init__(self, root_path):
        self.mDataX = []
        self.mDataY = []

        for img_path in glob.glob(root_path + r'\*'):
            img = Image.open(img_path)
            img = img.convert('L')  #  
            # new_size = np.array(img.size) / 4
            # new_size = new_size.astype(int)
            # img = img.resize(new_size, Image.BILINEAR)  #  ( , )(640, 480) (160, 120)
            img_data = np.array(img, dtype=float)
            img_data = img_data.reshape(-1)
            self.mDataX.append(img_data)
            img_id = int(img_path[-12:-9])  ## , 
            self.mDataY.append(img_id)

        self.mDataX = torch.tensor(self.mDataX)
        self.mDataY = torch.tensor(self.mDataY)

    def __getitem__(self, data_index):
        input_tensor = torch.tensor(self.mDataX[data_index])
        output_tensor = torch.tensor(self.mDataY[data_index])
        return input_tensor, output_tensor

    def __len__(self):
        return len(self.mDataX)

次は上に書いてあるものを呼び出します.pyファイル
import torch
from Path import *

train_set = Path(r'C:\datasets\ \ ')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False)  # , num_workers=8
print('OK! ', len(train_set), len(train_loader))

test_set = Path(r'C:\datasets\ \ ')
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False)  # , num_workers=8
print('OK! ', len(test_set), len(test_loader))

次は自分のトレーニングセットとテストセットの画像を使うことができます!!!