[pytorch]Tensor,numpyとPIL形式の相互変換

12437 ワード

[pytorch]Tensor,numpyとPIL形式の相互変換
  • PILとTensor
  • TensorとNumpy
  • 画像展示
  • 枚以上の画像の変換
  • pytorchまたはpythonで一般的に処理される画像は、これらのフォーマットにほかならない.
  • PIL:python持参画像処理ライブラリを用いて読み取った画像フォーマット
  • .
  • Numpy:python-opencvライブラリで読み込まれたピクチャフォーマット
  • Tensor:pytorchで訓練中に採用されたベクトルフォーマット(注意、その後の説明画像フォーマットはすべてRGB 3チャネル、24-bit真色、つまり私たちが普段使っている画像形式である.
  • PILとTensor
  • PIL to Tensor
  • import torch
    from PIL import Image
    import matplotlib.pyplot as plt
    
    # loader  torchvision    transforms  
    loader = transforms.Compose([
        transforms.ToTensor()])  
    
    unloader = transforms.ToPILImage()
    
    
    #       
    #   tensor  
    def image_loader(image_name):
        image = Image.open(image_name).convert('RGB')
        image = loader(image).unsqueeze(0)
        return image.to(device, torch.float)
    
    
  • Tensor to PIL
  • #   tensor  
    #   PIL    
    def tensor_to_PIL(tensor):
        image = tensor.cpu().clone()
        image = image.squeeze(0)
        image = unloader(image)
        return image
    
    

    TensorとNumpy
  • Tensor to Numpy
  • import cv2
    import torch
    import matplotlib.pyplot as plt
    
    def tensor_to_np(tensor):
        img = tensor.mul(255).byte()
        img = img.cpu().numpy().squeeze(0).transpose((1, 2, 0))
        return img
    
  • Numpy to Tensor
  • def toTensor(img):
        assert type(img) == np.ndarray,'the img type is {}, but ndarry expected'.format(type(img))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose((2, 0, 1)))
        return img.float().div(255).unsqueeze(0)  # 255     256
    
    

    写真展示
  • 展示Tensor
  • def show_from_tensor(tensor, title=None):
        img = tensor.clone()
        img = tensor_to_np(img)
        plt.figure()
        plt.imshow(img)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)
    
    
  • 展示Numpy
  • def show_from_cv(img, title=None):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.figure()
        plt.imshow(img)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)
    
    

    複数の画像の変換
    以上のコードはいずれも1枚の画像フォーマットの変換であり、複数に変更すると、コードを少し変更するだけで、以下の例になります.
    #   N x H x W X C  numpy          tensor  
    def toTensor(img):
        img = torch.from_numpy(img.transpose((0, 3, 1, 2)))
        return img.float().div(255).unsqueeze(0)