pytorchは自分でトレーニングセットとテストセットを作ることを実現します
9460 ワード
pytorchは画像認識に使用できますが、私たちが今ほとんど使っているのはMINISTとcifar 10ピクチャで、自分のトレーニングとテスト画像パスを使うには、トレーニングセットとテストセットを読み取るコードを作る必要があります.本稿ではpytorchがトレーニングセットの読み取りとテストセットの汎用コードを実現することについて述べる.
まず、画像の経路を読み取るフレームワーク:torch.utils.data.Datasetはpytorchがデータセットを表す抽象クラスであり、このクラスで自分のデータセットを処理する際にDatasetを継承しなければならない.次に、len(dataset)がデータセットのサイズgetitemを返すように、len(dataset)がi番目のデータサンプルを返すことができるように、次の関数を書き換える.
次に、img_を調整する必要がある具体的なコードクリップを示します.id=int(img_path[-12:-9])、この段落はファイル名を読むので、自分のファイル名に基づいてimg_を設定する必要がありますpathの数値.このファイルの名前はPathです.py
次は上に書いてあるものを呼び出します.pyファイル
次は自分のトレーニングセットとテストセットの画像を使うことができます!!!
まず、画像の経路を読み取るフレームワーク:torch.utils.data.Datasetはpytorchがデータセットを表す抽象クラスであり、このクラスで自分のデータセットを処理する際にDatasetを継承しなければならない.次に、len(dataset)がデータセットのサイズgetitemを返すように、len(dataset)がi番目のデータサンプルを返すことができるように、次の関数を書き換える.
次に、img_を調整する必要がある具体的なコードクリップを示します.id=int(img_path[-12:-9])、この段落はファイル名を読むので、自分のファイル名に基づいてimg_を設定する必要がありますpathの数値.このファイルの名前はPathです.py
import glob
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
class path(Dataset):
def __init__(self, root_path):
self.mDataX = []
self.mDataY = []
for img_path in glob.glob(root_path + r'\*'):
img = Image.open(img_path)
img = img.convert('L') #
# new_size = np.array(img.size) / 4
# new_size = new_size.astype(int)
# img = img.resize(new_size, Image.BILINEAR) # ( , )(640, 480) (160, 120)
img_data = np.array(img, dtype=float)
img_data = img_data.reshape(-1)
self.mDataX.append(img_data)
img_id = int(img_path[-12:-9]) ## ,
self.mDataY.append(img_id)
self.mDataX = torch.tensor(self.mDataX)
self.mDataY = torch.tensor(self.mDataY)
def __getitem__(self, data_index):
input_tensor = torch.tensor(self.mDataX[data_index])
output_tensor = torch.tensor(self.mDataY[data_index])
return input_tensor, output_tensor
def __len__(self):
return len(self.mDataX)
次は上に書いてあるものを呼び出します.pyファイル
import torch
from Path import *
train_set = Path(r'C:\datasets\ \ ')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False) # , num_workers=8
print('OK! ', len(train_set), len(train_loader))
test_set = Path(r'C:\datasets\ \ ')
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False) # , num_workers=8
print('OK! ', len(test_set), len(test_loader))
次は自分のトレーニングセットとテストセットの画像を使うことができます!!!