PyTorchでデータセットのカスタム読み込みを実装します。
VOC 2012セマンティックセグメンテーション（意味分割）データセットの読み込みを例に説明します。詳細はコード中のコメントを参照してください。
VocDataset.py
from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time
# Pascal VOC colour palette: entry i is the [R, G, B] colour that marks
# class index i in the SegmentationClass mask images.
VOC_COLORMAP = [
    [0, 0, 0],        # 0
    [128, 0, 0],      # 1
    [0, 128, 0],      # 2
    [128, 128, 0],    # 3
    [0, 0, 128],      # 4
    [128, 0, 128],    # 5
    [0, 128, 128],    # 6
    [128, 128, 128],  # 7
    [64, 0, 0],       # 8
    [192, 0, 0],      # 9
    [64, 128, 0],     # 10
    [192, 128, 0],    # 11
    [64, 0, 128],     # 12
    [192, 0, 128],    # 13
    [64, 128, 128],   # 14
    [192, 128, 128],  # 15
    [0, 64, 0],       # 16
    [128, 64, 0],     # 17
    [0, 192, 0],      # 18
    [128, 192, 0],    # 19
    [0, 64, 128],     # 20
]
def voc_label_indices(colormap, colormap2label):
    """Assign label indices for Pascal VOC2012 Dataset.

    Every pixel's colour is packed into the key ``(B * 256 + G) * 256 + R``
    and looked up in ``colormap2label`` to obtain its class index.

    Args:
        colormap: array-like of shape (H, W, 3) with the channel order R, G, B
            (a NumPy array of any integer dtype, or a CPU torch tensor).
        colormap2label: 1-D NumPy lookup table indexed by the packed colour
            key; it needs 256**3 entries to accept arbitrary colours.

    Returns:
        np.ndarray of dtype int64 and shape (H, W) with one class index per
        pixel.
    """
    # Widen to int64 *before* packing: with uint8 input (e.g. np.array of a
    # PIL image) ``channel * 256`` would overflow or raise.
    rgb = np.asarray(colormap, dtype=np.int64)
    idx = (rgb[:, :, 2] * 256 + rgb[:, :, 1]) * 256 + rgb[:, :, 0]
    return colormap2label[idx].astype(np.int64)
class MyDataset(data.Dataset):
    """Pascal VOC2012 semantic-segmentation dataset.

    Reads the image ids listed in ``ImageSets/Segmentation/train.txt`` (or
    ``val.txt``), then for each id loads ``JPEGImages/<id>.jpg`` and the
    colour-coded mask ``SegmentationClass/<id>.png``, centre-crops both to
    ``crop_size`` and converts the mask colours to class indices.

    Args:
        root: path to the ``VOC2012`` directory.
        is_train: read the ``train`` split if true, otherwise ``val``.
        crop_size: (height, width) passed to ``transforms.CenterCrop``.
    """

    def __init__(self, root, is_train, crop_size=(320, 480)):
        # Channel statistics used to normalise the input image.
        self.rgb_mean = (0.485, 0.456, 0.406)
        self.rgb_std = (0.229, 0.224, 0.225)
        self.root = root
        self.crop_size = crop_size

        split = 'train.txt' if is_train else 'val.txt'
        txt_fname = os.path.join(root, 'ImageSets', 'Segmentation', split)
        with open(txt_fname, 'r', encoding='utf-8') as f:
            self.images = f.read().split()

        # One record per sample: image path, mask path and image id.
        self.files = []
        for name in self.images:
            self.files.append({
                "img": os.path.join(self.root, "JPEGImages/%s.jpg" % name),
                "label": os.path.join(self.root, "SegmentationClass/%s.png" % name),
                "name": name,
            })

        # Lookup table: packed colour key (B*256+G)*256+R -> class index.
        # Colours not in VOC_COLORMAP stay 0. uint8 holds up to 256 class
        # indices (VOC_COLORMAP has 21) and needs 16 MiB instead of the
        # 128 MiB of a default float64 table.
        self.colormap2label = np.zeros(256 ** 3, dtype=np.uint8)
        for i, cm in enumerate(VOC_COLORMAP):
            self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i

        # Build the transforms once rather than on every __getitem__ call.
        self.img_transform = transforms.Compose([
            transforms.CenterCrop(self.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(self.rgb_mean, self.rgb_std),
        ])
        self.label_transform = transforms.CenterCrop(self.crop_size)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``image`` is the centre-cropped, normalised float tensor produced by
        ``ToTensor`` + ``Normalize``; ``label`` is an int64 NumPy array of
        per-pixel class indices for the same crop. Mask colours that are not
        in VOC_COLORMAP map to class 0.
        """
        datafiles = self.files[index]
        # Open inside ``with`` so file handles are released promptly;
        # convert() forces the pixel data to load before the file closes.
        # Converting the image to RGB keeps the 3-channel Normalize valid
        # even for non-RGB JPEGs; the mask is converted so its pixels are
        # RGB colours that can be looked up in colormap2label.
        with Image.open(datafiles["img"]) as im:
            image = im.convert('RGB')
        with Image.open(datafiles["label"]) as lb:
            label = lb.convert('RGB')

        crop_image = self.img_transform(image)
        crop_label = self.label_transform(label)
        # Widen to int64 before the colour key is packed.
        crop_label = np.asarray(crop_label, dtype=np.int64)
        return crop_image, voc_label_indices(crop_label, self.colormap2label)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.files)
Train.py
import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from VocDataset import MyDataset
# Pascal VOC colour palette used to colourise label maps: entry i is the
# [R, G, B] colour for class index i. Same values as VOC_COLORMAP in
# VocDataset.py; keep the two lists in sync.
VOC_COLORMAP = [
    [0, 0, 0],        # 0
    [128, 0, 0],      # 1
    [0, 128, 0],      # 2
    [128, 128, 0],    # 3
    [0, 0, 128],      # 4
    [128, 0, 128],    # 5
    [0, 128, 128],    # 6
    [128, 128, 128],  # 7
    [64, 0, 0],       # 8
    [192, 0, 0],      # 9
    [64, 128, 0],     # 10
    [192, 128, 0],    # 11
    [64, 0, 128],     # 12
    [192, 0, 128],    # 13
    [64, 128, 128],   # 14
    [192, 128, 128],  # 15
    [0, 64, 0],       # 16
    [128, 64, 0],     # 17
    [0, 192, 0],      # 18
    [128, 192, 0],    # 19
    [0, 64, 128],     # 20
]
# Load the VOC2012 training split and display, for every batch of 4, the
# first image next to its colourised label map.
root = '../data/VOCdevkit/VOC2012'
train_data = MyDataset(root, True)
trainloader = data.DataLoader(train_data, 4)

# Palette as an array so a whole label map is colourised with one
# fancy-indexing step instead of a per-pixel putpixel() loop.
palette = np.array(VOC_COLORMAP, dtype=np.uint8)

# Loop variables are named so they do not shadow the ``data`` module.
for getimgs, labels in trainloader:
    # Undo the dataset's Normalize and clamp to [0, 1]; ToPILImage on the
    # normalised values would otherwise show wrapped-around colours.
    mean = getimgs.new_tensor(train_data.rgb_mean).view(-1, 1, 1)
    std = getimgs.new_tensor(train_data.rgb_std).view(-1, 1, 1)
    img = transforms.ToPILImage()((getimgs[0] * std + mean).clamp(0, 1))

    # labels[0] holds one class index per pixel; indexing the palette with
    # it yields an (H, W, 3) uint8 RGB array of the same size, so no
    # hard-coded 480x320 size or transpose is needed.
    label_map = labels[0].numpy()
    new_im = Image.fromarray(palette[label_map])

    fig = plt.figure("image")
    ax1, ax2 = fig.subplots(1, 2)
    ax1.imshow(img)
    ax2.imshow(new_im)
    plt.show()
以上、PyTorchでデータセットのカスタム読み込みを実装する方法を紹介しました。本記事でお伝えする内容は以上です。参考になれば幸いです。どうぞよろしくお願いします。