TypeError: img should be PIL Image. Got class torch.Tensor

Background:
In PyTorch, I load the MNIST dataset and visualize one batch with the following code.
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt

# part 1: prepare the datasets, using the API that torch provides
mnist_train_dataset = datasets.MNIST(root="./data/",
                                      train=True,
                                      download=True,
                                      transform=
                                        transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5],std=[0.5]),transforms.Resize((28,28))])
                                    )
                                        
mnist_test_dataset = datasets.MNIST(root="./data/",
                                      train=False,
                                      download=True,
                                      transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((28,28))])
                                    )

# part 2: wrap the datasets in dataloaders
data_loader_train = torch.utils.data.DataLoader(
    dataset=mnist_train_dataset,
    batch_size=128,
    shuffle=True
)

data_loader_test = torch.utils.data.DataLoader(
    dataset=mnist_test_dataset,
    batch_size = 1,
    shuffle=True
)


# part 3: preview one batch
images,labels = next(iter(data_loader_train))
# TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1,2,0)
std=mean=[0.5,0.5,0.5]
img = img * std + mean
# Without the de-normalization above, imshow warns: Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
print([int(label) for label in labels])
plt.imshow(img)
plt.show()

This raises the following error:
Traceback (most recent call last):
  File "d:/GitHub/studyNote/pytorch  /mnist.torch.py", line 45, in <module>
    images,labels = next(iter(data_loader_train))
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 61, in __call__
    img = t(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 196, in __call__
    return F.resize(img, self.size, self.interpolation)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py", line 229, in resize
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

Thoughts:
The transform wants an image in PIL format, and transforms happens to offer a method for exactly that, transforms.ToPILImage(), so I tried:
transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize(mean=[0.5],std=[0.5]),
                              transforms.Resize([28,28]),
                              transforms.ToPILImage()
                             ])

But it still fails:
Traceback (most recent call last):
  File "d:/GitHub/studyNote/pytorch  /mnist.torch.py", line 45, in <module>
    images,labels = next(iter(data_loader_train))
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 61, in __call__
    img = t(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 196, in __call__
    return F.resize(img, self.size, self.interpolation)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py", line 229, in resize
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

Searching Stack Overflow through Bing turned up a similar error:
train_transforms = transforms.Compose(
[transforms.Resize(255), 
transforms.CenterCrop(224), 
transforms.ToTensor(), 
transforms.RandomHorizontalFlip(), 
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
– from https://stackoverflow.com/questions/57079219/img-should-be-pil-image-got-class-torch-tensor
The answer posted there reads: transforms.RandomHorizontalFlip() works on PIL.Images, not torch.Tensor. In your code above, you are applying transforms.ToTensor() prior to transforms.RandomHorizontalFlip(), which results in tensor.
But, as per the official pytorch documentation here,
transforms.RandomHorizontalFlip() horizontally flip the given PIL Image randomly with a given probability.
So, just change the order of your transformation in above code, like below:
train_transforms = transforms.Compose([transforms.Resize(255), 
                                       transforms.CenterCrop(224),  
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(), 
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

So it is an ordering problem: the transforms need to be swapped so that ToTensor comes after RandomHorizontalFlip.
Solution:
I tried the same approach on my own problem.
From:
transform=transforms.Compose([
                              transforms.ToTensor(),
                              transforms.Normalize(mean=[0.5],std=[0.5]),
                              transforms.Resize([28,28]),
                              transforms.ToPILImage()
                             ])

To:
transform=transforms.Compose([
                              transforms.Resize([28,28]),
                              transforms.ToTensor(),
                              transforms.Normalize(mean=[0.5],std=[0.5])
                              # transforms.ToPILImage()
                             ])

Curious about the ordering, I tried one more variation:
transform=transforms.Compose([
                              transforms.Scale([28,28]),
                              transforms.Normalize(mean=[0.5],std=[0.5]),
                              transforms.ToTensor()
                             ])

This raised an error:
  File "d:/GitHub/studyNote/pytorch  /mnist.torch.py", line 47, in <module>
    images,labels = next(iter(data_loader_train))
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 61, in __call__
    img = t(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 164, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py", line 201, in normalize
    raise TypeError('tensor is not a torch image.')
TypeError: tensor is not a torch image.

So it seems ToTensor has to come before Normalize.
If you discover anything else, feel free to add it in the comments.