[論文実習]Pix 2 Pix


この記事は,CVPR,2017年に発表された従来の広告ネットワーク(Pix 2 Pix)を有する画像から画像への変換論文の実験である.
  • Summary link -> link
  • Pix2Pix

  • Unpaired Image-to-Image Translation using Cycle Consistent Adversarial Networks
  • https://learnopencv.com/paired-image-to-image-translation-pix2pix/

  • 画素から画像への変換または画素への変換と呼ばれるGANモデルは、Genretorが入力としてノイズベクトルを受け入れるため、従来のGANモデルのアルゴリズムを導入し、重要なアーキテクチャを変更する.画像から画像への変換は、ある領域の画像を別の領域に変換するタスクで、入力画像と出力画像のマッピングを学ぶことができます.したがって、トレーニングデータセットでは、異なる領域からのデータが使用されます.
    Image-to-Image翻訳はペアリング/ペアリングされていないものであってもよく、この文章ではペアリング翻訳について議論します!

    2. Applications of Pix2Pix

  • 白黒画像をカラー画像
  • に変換する.
  • エッジを有意義な写真に変換
  • 航空写真を地図
  • に変換
  • 低解像度画像を高解像度
  • に変換する.

    3. What is a Pix2Pix GAN?


    Pix 2 Pixは初めてBerkeley AI ResearchでCVPR 2017で発表された.上の論文は何度も引用されているので、非常に多く使われています.
    Pix 2 Pix GANはCGANのアイデアを拡張し、画像への変換を追加しました.Pix 2 Pix GANは、ジェネレータにノイズベクトルの概念を使用しない.
  • 画像はGenerator入力、翻訳後の画像出力
  • に入る.
  • Distributorは条件識別器であり、入力された画像と条件はrea l/kakeである.キャラクターは以前のようにreal/lakeを判断します.
  • Pix 2 Pix GANの最終目標は他のGANと同じである.ジェネレータにDistributorを欺く画像を生成させる.
  • 初期のGANアーキテクチャでは,ノイズバックグラウンドはランダムであり,他の出力の作成に寄与した.しかし,この概念はPix 2 Pixには適用されず,論文の著者らはGenerator出力において一定の確率を保つ方法を見つけた.

    4. UNET Generator


    の初期GANSアーキテクチャと同様に、ジェネレータは入力方式でノイズベクトルを受信するのではなく、入力方式で画像を受信し、ジェネレータを自動エンコーダとして構成する.したがって、Generatorにはコーデックネットワークがあります.Pix 2 PixはUNETをジェネレータとして使用し、ミラー層間にskip-connectionsがあることを特徴とする.Skip-connectionsは、AEがダウンサンプリング時に失われた情報(特に低層特性)を保持し、逆伝搬時の消失勾配を防止することができる.

    5. PatchGAN Discriminator


    他のGANSと同様、 Pix 2 PixのDistributorもrea l/akeを区別するためです.Pix 2 Pixは、確率値(スカラー)を出力としてではなく、領域内のTensor値を返すPatchGANというDistributorを使用します.すなわち、discriptorは、入力画像に対して、画像全体を一度に判断するのではなく、マトリクス値を返し、詳細領域に対して区切り値を返す.

    PatchGAN構造
  • Distributorは標準のConvolution-Batch Normalization-RELUブロック構造を採用し、
  • ネットワークは、rea l/lake予測の単一特徴マッピングを出力する.
  • 従来のリスナー:CGANの影響を受けて、モデムは所定の条件に基づいて入力画像の真偽を判断する.従って、入力はrea l/ake画像とcondition画像とが直列に接続されている.
  • Patch:[256256,n channels]レベルの入力画像はパッチに分類され、出力は[30,30]タイプのTenserとは逆になることを意味する.これは、各メッシュに[70,70]の面積のパッチの確率値が含まれていることを意味する.
    上記の図に示すように、出力予測マトリクスに格納される値は、rea l/lakeおよびfakeである.PatchGAN NxNにおけるパッチを判断した後,平均で最終結果Dを得た.上記論文では70 x 70サイズのパッチが最も有効である.PatchGANの長所は、選別器が本物だということですfakeとヘルプジェネレータを判断して騙すのが上手ですGeneratorが生成した画像がPatchGANに入ると、0でいっぱいのマトリクスを返すことを学びます.逆に、実際の画像については、PatchGANは満1のマトリクスを返すことを学習する.
  • 6. Pix2Pix Loss


    ジェネレータとDistributorを最適化するために,標準的な訓練法は勾配stepを用いてGとDを交互に学習することである.本論文では,Distributorの役割は不変であるが,Generatorの役割はDを欺き,2層目の距離をground‐truthに最も近づけることである.本論文ではL 1とL 2も比較し,L 1のぼかし度が低いことを示した.

    Discriminator


    Pix 2 Pix Distributorで使用した損失関数は以前のGANSモデルと同じである.すなわち、realとfakeを区別するために、負のlog-likelionを最小化することを目標としています.著者らは,Generatorより学習速度が速くなるのを防ぐため,2つの部分に分けた.

    Generator


    実際のラベル値はGeneratorを学習するために使用されます.さらに,誤りを最小化するために,本論文ではさらにL 1損失項を追加した.L 1 loss値は実際の答えと予測値との差であり、L 1制限で変換された画像がターゲット画像とあまり似ていなければ延長の役割を果たすことができる.

    Total Loss


    従来のGANモデルの損失関数には,前述のL 1損失項を加えると,最終的なPIX 2 PIX GANの総損失といえる.

    7. Pytorch Implementation


    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import torch.utils.data as data
    from torch.utils.data import DataLoader
    import torchvision.transforms as transforms
    
    from PIL import Image
    import matplotlib.pyplot as plt
    from math import log10 # For metric function
    
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    DataLoader

    # Load Dataset from ImageFolder
    class Dataset(data.Dataset): # torch기본 Dataset 상속받기
        def __init__(self, image_dir, direction):
            super(Dataset, self).__init__() # 초기화 상속
            self.direction = direction # 
            self.a_path = os.path.join(image_dir, "a") # a는 건물 사진
            self.b_path = os.path.join(image_dir, "b") # b는 Segmentation Mask
            self.image_filenames = [x for x in os.listdir(self.a_path)] # a 폴더에 있는 파일 목록
            self.transform = transforms.Compose([transforms.Resize((256, 256)), # 이미지 크기 조정
                                                transforms.ToTensor(), # Numpy -> Tensor
                                                 transforms.Normalize(mean=(0.5, 0.5, 0.5), 
                                                    std=(0.5, 0.5, 0.5)) # Normalization : -1 ~ 1 range
                                                ])
            self.len = len(self.image_filenames)
        
        def __getitem__(self, index):
            
            # 건물사진과 Segmentation mask를 각각 a,b 폴더에서 불러오기
            a = Image.open(os.path.join(self.a_path, self.image_filenames[index])).convert('RGB') # 건물 사진
            b = Image.open(os.path.join(self.b_path, self.image_filenames[index])).convert('RGB') # Segmentation 사진
            
            # 이미지 전처리
            a = self.transform(a)
            b = self.transform(b)
            
            if self.direction == "a2b": # 건물 -> Segmentation
                return a, b
            else:  # Segmentation -> 건물
                return b, a
        
        def __len__(self):
            return self.len
            
    train_dataset = Dataset("./data/facades/train/", "b2a")
    test_dataset = Dataset("./data/facades/test/", "b2a")
    
    train_loader = DataLoader(dataset=train_dataset, num_workers=0, batch_size=1, shuffle=True) # Shuffle
    test_loader = DataLoader(dataset=test_dataset, num_workers=0, batch_size=1, shuffle=False)
    num workerは、現在の作業環境でデータをロードするプロセスを制御するパラメータです.0はDefault値で、0はMain Processにデータをロードすることを示します.複数の処理を使用してデータをロードする場合は、プロセスの数に応じてパラメータ値を調整できます.
    # -1 ~ 1사이의 값을 0~1사이로 만들어준다
    def denorm(x):
        out = (x + 1) / 2
        return out.clamp(0, 1)
    
    # 이미지 시각화 함수
    def show_images(real_a, real_b, fake_b):
        plt.figure(figsize=(30,90))
        plt.subplot(131)
        plt.imshow(real_a.cpu().data.numpy().transpose(1,2,0))
        plt.xticks([])
        plt.yticks([])
        
        plt.subplot(132)
        plt.imshow(real_b.cpu().data.numpy().transpose(1,2,0))
        plt.xticks([])
        plt.yticks([])
        
        plt.subplot(133)
        plt.imshow(fake_b.cpu().data.numpy().transpose(1,2,0))
        plt.xticks([])
        plt.yticks([])
        
        plt.show()

    Conv & DeConv function

    # Conv -> Batchnorm -> Activate function Layer
    '''
    코드 단순화를 위한 convolution block 생성을 위한 함수
    Encoder에서 사용될 예정
    '''
    def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True, activation='relu'):
        layers = []
        
        # Conv layer
        layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False))
        
        # Batch Normalization
        if bn:
            layers.append(nn.BatchNorm2d(c_out))
        
        # Activation
        if activation == 'lrelu':
            layers.append(nn.LeakyReLU(0.2))
        elif activation == 'relu':
            layers.append(nn.ReLU())
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        elif activation == 'none':
            pass
        
        return nn.Sequential(*layers)
    
    # Deconv -> BatchNorm -> Activate function Layer
    '''
    코드 단순화를 위한 convolution block 생성을 위한 함수
    Decoder에서 이미지 복원을 위해 사용될 예정
    '''
    def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True, activation='lrelu'):
        layers = []
        
        # Deconv.
        layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False))
        
        # Batchnorm
        if bn:
            layers.append(nn.BatchNorm2d(c_out))
        
        # Activation
        if activation == 'lrelu':
            layers.append(nn.LeakyReLU(0.2))
        elif activation == 'relu':
            layers.append(nn.ReLU())
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        elif activation == 'none':
            pass
                    
        return nn.Sequential(*layers)

    Generator - UNET

    class Generator(nn.Module):
        # initializers
        def __init__(self):
            super(Generator, self).__init__()
            # Unet encoder
            self.conv1 = conv(3, 64, 4, bn=False, activation='lrelu') # (B, 64, 128, 128)
            self.conv2 = conv(64, 128, 4, activation='lrelu') # (B, 128, 64, 64)
            self.conv3 = conv(128, 256, 4, activation='lrelu') # (B, 256, 32, 32)
            self.conv4 = conv(256, 512, 4, activation='lrelu') # (B, 512, 16, 16)
            self.conv5 = conv(512, 512, 4, activation='lrelu') # (B, 512, 8, 8)
            self.conv6 = conv(512, 512, 4, activation='lrelu') # (B, 512, 4, 4)
            self.conv7 = conv(512, 512, 4, activation='lrelu') # (B, 512, 2, 2)
            self.conv8 = conv(512, 512, 4, bn=False, activation='relu') # (B, 512, 1, 1)
    
            # Unet decoder
            self.deconv1 = deconv(512, 512, 4, activation='relu') # (B, 512, 2, 2)
            self.deconv2 = deconv(1024, 512, 4, activation='relu') # (B, 512, 4, 4)
            self.deconv3 = deconv(1024, 512, 4, activation='relu') # (B, 512, 8, 8) # Hint : U-Net에서는 Encoder에서 넘어온 Feature를 Concat합니다! (Channel이 2배)
            self.deconv4 = deconv(1024, 512, 4, activation='relu') # (B, 512, 16, 16)
            self.deconv5 = deconv(1024, 256, 4, activation='relu') # (B, 256, 32, 32)
            self.deconv6 = deconv(512, 128, 4, activation='relu') # (B, 128, 64, 64)
            self.deconv7 = deconv(256, 64, 4, activation='relu') # (B, 64, 128, 128)
            self.deconv8 = deconv(128, 3, 4, activation='tanh') # (B, 3, 256, 256)
    
        # forward method
        def forward(self, input):
            # Unet encoder
            e1 = self.conv1(input)
            e2 = self.conv2(e1)
            e3 = self.conv3(e2)
            e4 = self.conv4(e3)
            e5 = self.conv5(e4)
            e6 = self.conv6(e5)
            e7 = self.conv7(e6)
            e8 = self.conv8(e7)
                                  
            # Unet decoder
            d1 = F.dropout(self.deconv1(e8), 0.5, training=True)
            d2 = F.dropout(self.deconv2(torch.cat([d1, e7], 1)), 0.5, training=True)
            d3 = F.dropout(self.deconv3(torch.cat([d2, e6], 1)), 0.5, training=True)
            d4 = self.deconv4(torch.cat([d3, e5], 1))
            d5 = self.deconv5(torch.cat([d4, e4], 1))
            d6 = self.deconv6(torch.cat([d5, e3], 1))
            d7 = self.deconv7(torch.cat([d6, e2], 1))
            output = self.deconv8(torch.cat([d7, e1], 1))
            
            return output

    Discriminator - PatchGAN


    5つのボリュームの受信ドメイン、すなわちパッチのみを見て、本物か偽物かを判断します.
    class Discriminator(nn.Module):
        # initializers
        def __init__(self):
            super(Discriminator, self).__init__()
            self.conv1 = conv(6, 64, 4, bn=False, activation='lrelu')
            self.conv2 = conv(64, 128, 4, activation='lrelu')
            self.conv3 = conv(128, 256, 4, activation='lrelu')
            self.conv4 = conv(256, 512, 4, 1, 1, activation='lrelu')
            self.conv5 = conv(512, 1, 4, 1, 1, activation='none')
    
        # forward method
        def forward(self, input):
            out = self.conv1(input)
            out = self.conv2(out)
            out = self.conv3(out)
            out = self.conv4(out)
            out = self.conv5(out)
    
            return out

    Train

    # Generator와 Discriminator를 GPU로 보내기
    G = Generator().cuda()
    D = Discriminator().cuda()
    
    criterionL1 = nn.L1Loss().cuda()
    criterionMSE = nn.MSELoss().cuda()
    
    # Setup optimizer
    g_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    # Train
    for epoch in range(1, 100):
        for i, (real_a, real_b) in enumerate(train_loader, 1):
            # forward
            real_a, real_b = real_a.cuda(), real_b.cuda()
            real_label = torch.ones(1).cuda()
            fake_label = torch.zeros(1).cuda()
            
            fake_b = G(real_a) # G가 생성한 fake Segmentation mask
            
            #============= Train the discriminator =============#
            # train with fake
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = D.forward(fake_ab.detach())
            loss_d_fake = criterionMSE(pred_fake, fake_label)
    
            # train with real
            real_ab = torch.cat((real_a, real_b), 1)
            pred_real = D.forward(real_ab)
            loss_d_real = criterionMSE(pred_real, real_label)
            
            # Combined D loss
            loss_d = (loss_d_fake + loss_d_real) * 0.5
            
            # Backprop + Optimize
            D.zero_grad()
            loss_d.backward()
            d_optimizer.step()
    
            #=============== Train the generator ===============#
            # First, G(A) should fake the discriminator
            fake_ab = torch.cat((real_a, fake_b), 1)
            pred_fake = D.forward(fake_ab)
            loss_g_gan = criterionMSE(pred_fake, real_label)
    
            # Second, G(A) = B
            loss_g_l1 = criterionL1(fake_b, real_b) * 10
            
            loss_g = loss_g_gan + loss_g_l1
            
            # Backprop + Optimize
            G.zero_grad()
            D.zero_grad()
            loss_g.backward()
            g_optimizer.step()
            
            if i % 200 == 0:
                print('======================================================================================================')
                print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f'
                      % (epoch, 100, i, len(train_loader), loss_d.item(), loss_g.item()))
                print('======================================================================================================')
                show_images(denorm(real_a.squeeze()), denorm(real_b.squeeze()), denorm(fake_b.squeeze()))
    結果の生成
    Epoch [99/100], Step[400/400], d_loss: 0.2500, g_loss: 1.5049
    ======================================================================================================```

    左の順番は、Facade入力画像、実画像、および生成された画像(Facade->RGB)です.

    Reference

  • Image-to-Image Translation with Conditional Adversarial Networks
  • Fastキャンパス、The Redが主宰する女性教授の実習資料の中で