PyTorchはDatasetとDataloaderが遭遇した問題を解決します。


今日はPyTorchを使っていますが、Datasetが問題になりました。コードを先に見ます

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])
結果運行時報が間違っています。RuntimeError:invalid argment 0:Sizes of tens must match except in dimension 0.Got 3 and 1 in dimension 1 at/opt/conda/conda-bld/pytouch_152282087074/work/touch/lib/TH/generic/THTensorMath.cc:2897
Googleが発見したのは、読み込んだ画像には階調図(1チャネル)があり、ほとんどがRGB画像(3チャネル)であり、透明度のあるものもある(4チャネル)。
。これは、読み込み後の最後の次元(チャネル数)が一致しない(おそらく1、3、または4)ことをもたらします。
Dataloaderがbatch dataを作る時、tenssorのshopは同じでなければいけません。このエラーを報告しました。解決の方法は、[img]img.com nvert(「RGB」)です。終了
コード全体は以下の通りです

class psDataset(Dataset):
  def __init__(self, x, y, transforms = None):
    super(Dataset, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = Compose([Resize((224, 224)), ToTensor()])
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    img = self.transforms(img)    
    return img, torch.tensor([[self.y[idx]]])
以上のPyTorchでDatasetとDataloaderが遭遇した問題を解決しました。小編集が皆さんに共有した内容は全部です。参考にしてもらいたいです。皆さんも応援してください。