[PyTorch小试牛刀]実戦六・自分のデータセットを用意して训练に使う(猫と犬の大戦データセットに基づく)
[PyTorch小试牛刀]実戦六・自分のデータセットを用意して训练に使う(猫と犬の大戦データセットに基づく)
上記のいくつかの実戦では、Pytorch公式に準備されたFashionMNISTデータセットを使用してトレーニングとテストを行いました.このブログでは、より多くのシーンに対応するために、データセットを自分で準備する方法について説明します.
私たちが今回使っているのは猫と犬の大戦データセットで、始まる前にデータを処理して、形式は以下の通りです.
datas │ └───train │ │ │ └───cats │ │ │ cat1000.jpg │ │ │ cat1001.jpg │ │ │ … │ └───dogs │ │ │ dog1000.jpg │ │ │ dog1001.jpg │ │ │ … └───valid │ │ │ └───cats │ │ │ cat0.jpg │ │ │ cat1.jpg │ │ │ … │ └───dogs │ │ │ dog0.jpg │ │ │ dog1.jpg │ │ │ …
trainデータセットには23000個のデータがあり、validデータセットには2000個のデータがネットワーク性能を検証するために使用されている.
コード部分1.ステルス辞書形式を採用し、コードが簡潔で、理解しにくい
2.顕性辞書形式を採用し、コードが少し多く、理解しやすい
出力結果
上記のいくつかの実戦では、Pytorch公式に準備されたFashionMNISTデータセットを使用してトレーニングとテストを行いました.このブログでは、より多くのシーンに対応するために、データセットを自分で準備する方法について説明します.
私たちが今回使っているのは猫と犬の大戦データセットで、始まる前にデータを処理して、形式は以下の通りです.
datas │ └───train │ │ │ └───cats │ │ │ cat1000.jpg │ │ │ cat1001.jpg │ │ │ … │ └───dogs │ │ │ dog1000.jpg │ │ │ dog1001.jpg │ │ │ … └───valid │ │ │ └───cats │ │ │ cat0.jpg │ │ │ cat1.jpg │ │ │ … │ └───dogs │ │ │ dog0.jpg │ │ │ dog1.jpg │ │ │ …
trainデータセットには23000個のデータがあり、validデータセットには2000個のデータがネットワーク性能を検証するために使用されている.
コード部分1.ステルス辞書形式を採用し、コードが簡潔で、理解しにくい
import torch as t
import torchvision as tv
import os
data_dir = "./datas"
BATCH_SIZE = 100
EPOCH = 10
transform = {
x:tv.transforms.Compose(
[tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]#tv.transforms.Resize
)
for x in ["train","valid"]
}
datasets = {
x:tv.datasets.ImageFolder(root = os.path.join(data_dir,x),transform=transform[x])
for x in ["train","valid"]
}
dataloader = {
x:t.utils.data.DataLoader(dataset= datasets[x],
batch_size=BATCH_SIZE,
shuffle=True
)
for x in ["train","valid"]
}
b_x,b_y = next(iter(dataloader["train"]))
print(b_x.shape,b_y.shape)
index_classes = datasets["train"].class_to_idx
print(index_classes)
2.顕性辞書形式を採用し、コードが少し多く、理解しやすい
import torch as t
import torchvision as tv
data_dir = "./datas"
BATCH_SIZE = 100
EPOCH = 10
transform = {
"train":tv.transforms.Compose(
[tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]
),
"valid":tv.transforms.Compose(
[tv.transforms.Resize([64,64]),tv.transforms.ToTensor()]
),
}
datasets = {
"train":tv.datasets.ImageFolder(root = os.path.join(data_dir,"train"),transform=transform["train"]),
"vaild":tv.datasets.ImageFolder(root = os.path.join(data_dir,"vaild"),transform=transform["vaild"]),
}
dataloader = {
"train":t.utils.data.DataLoader(dataset= datasets["train"],
batch_size=BATCH_SIZE,
shuffle=True
),
"valid":t.utils.data.DataLoader(dataset= datasets["valid"],
batch_size=100,
shuffle=True
)
}
b_x,b_y = next(iter(dataloader["train"]))
print(b_x.shape,b_y.shape)
index_classes = datasets["train"].class_to_idx
print(index_classes)
出力結果
torch.Size([100, 3, 64, 64]) torch.Size([100])
{'cats': 0, 'dogs': 1}