[第7週目]MNIST GAN例


データセットの準備


トレーニングデータ
テストデータ

dataloader

class MnistDataset(Dataset):
	def __init__(self, csv_file):
		self.data_df = pd.read_csv(csv_file, header=None)

	# Dataset을 위한 특수 메소드, 데이터셋의 길이를 반환
	def __len__(self):
		return len(self.data_df)

	# Dataset을 위한 특수 메소드, n번째 아이템을 반환
	def __getitem__(self, index):
		label = self.data_df.iloc[index, 0]
		# 10개 숫자중 label의 숫자에만 1로 one-hot encoding
		target = torch.zeros((10))
		target[label] = 1.0

		# 0-255의 이미지를 0-1로 정규화
		image_values = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0

		# 레이블, 이미지 데이터 센서, 목표 텐서 반환
		return label, image_values, target

	def plot_image(self, index):
		img = self.data_df.iloc[index, 1:].values.reshape(28, 28)
		plt.title("label = " + str(self.data_df.iloc[index, 0]))
		plt.imshow(img, interpolation='none', cmap = "Blues")
		plt.show()
mnist_dataset = MnistDataset('data/mnist_train.csv')

mnist_dataset.plot_image(17)

判別器

# discriminator class

class Discriminator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.Sigmoid(),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )
        
        # create loss function
        self.loss_function = nn.MSELoss()

        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []
    
    def forward(self, inputs):
        # simply run model
        return self.model(inputs)
    
    def train(self, inputs, targets):
        # calculate the output of the network
        outputs = self.forward(inputs)
        
        # calculate loss
        loss = self.loss_function(outputs, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
        if (self.counter % 10000 == 0):
            print("counter = ", self.counter)

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
識別器の性能検査(真偽を区別する能力があるかどうか)
def generate_random(size):
    random_data = torch.rand(size)
    return random_data

D = Discriminator()

for label, image_data_tensor, target_tensor in mnist_dataset:
    D.train(image_data_tensor, torch.FloatTensor([1.0]))
    D.train(generate_random(784), torch.FloatTensor([0.0]))
    
    
for i in range(4):
    image_data_tensor = mnist_dataset[random.randint(0, 60000)][1]
    print(D.forward(image_data_tensor).item())
    
for i in range(4):
    print(D.forward(generate_random(784)).item())
![](https://media.vlpt.us/images/withdongyeong/post/eba029bd-019f-4874-ad27-d49140ab0ef4/image.png)

ビルダー

# generator class

class Generator(nn.Module):
    
    def __init__(self):
        # initialise parent pytorch class
        super().__init__()
        
        # define neural network layers
        self.model = nn.Sequential(
            nn.Linear(1, 200),
            nn.Sigmoid(),
            nn.Linear(200, 784),
            nn.Sigmoid()
        )

        # create optimiser, simple stochastic gradient descent
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # counter and accumulator for progress
        self.counter = 0;
        self.progress = []
        
    def forward(self, inputs):        
        # simply run model
        return self.model(inputs)
    
    def train(self, D, inputs, targets):
        # calculate the output of the network
        g_output = self.forward(inputs)
        
        # pass onto Discriminator
        d_output = D.forward(g_output)
        
        # calculate error
        loss = D.loss_function(d_output, targets)

        # increase counter and accumulate error every 10
        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        # zero gradients, perform a backward pass, update weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()    
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
  • ジェネレータ能力確認
  • G = Generator()
    output = G.forward(generate_random(1))
    img = output.detach().numpy().reshape(28, 28)
    plt.imshow(img, interpolation='none', cmap='Blues')

    特に、MSE lossは現在の損失値として機能する.
    0.5の平方、すなわち0.25はジェネレータと判別器のバランス状態である.

    学習後の分析

    %time
    D = Discriminator()
    G = Generator()
    for label, image_data_tensor, target_tensor in mnist_dataset:
        # 1단계, 참에 대한 판별기 훈련
        D.train(image_data_tensor, torch.FloatTensor([1.0]))
        
        # 2단계, 거짓에 대한 판별기 훈련
        D.train(G.forward(generate_random(1)).detach(), torch.FloatTensor([0.0]))
        
        # 3단계, 생성기 훈련
        G.train(D, generate_random(1), torch.FloatTensor([1.0]))

    初期判別器がリードし、ジェネレータの性能が徐々に向上し、バランスがとれ、最終判別器が引き続き性能優位にある場合
    f, axarr = plt.subplots(2,3, figsize=(16, 8))
    for i in range(2):
        for j in range(3):
            output = G.forward(generate_random(1))
            img = output.detach().numpy().reshape(28, 28)
            axarr[i, j].imshow(img, interpolation='none', cmap='Blues')

    生成された画像には一定のパターンがあるように見えます.
    同じ形態を持つ(差異があっても認識できない)
    この現象をモードクラッシュと呼ぶ.
    モードがクラッシュすると、ジェネレータは1つまたはごく少数の選択しか生成しません.2020年の授業資料基準という現象について、研究を続けているという.
    この点を説明する1つの像の理論は,生成時に判別器より先に離れた後,常に判別器を通過できる蜜点を発見し,その後使用を継続することである.
    判別器をより頻繁に訓練することで性能を向上させることは,このような状況を緩和できると言えるが,実際の効果はない.
    訓練の質より質が重要だからだ.
    鑑別器が正常に動作しない(品質が低下する)と、良いフィードバックが得られないからです.

    強化されたGAN性能


    訓練の質を高めることは理解できる(良いフィードバックが必要)
    訓練の質を高めるために.
  • 損失関数
  • 分類問題では,バイナリクロスエントロピーはMSElossよりも
  • 効率的である.
  • での理想損失値は0.25ではなくln(2),0.69であった.
  • アクティブ化関数
  • 傾斜消失対応
  • 正規化
  • の平均値を0にし、その分布を制限することで、極端値
  • を回避する.
    使用できます.
    そこで,判別器とジェネレータのニューラルネットワークを少し変える.
    # 판별기
            self.model = nn.Sequential(
                nn.Linear(784, 200),
                nn.LeakyReLU(0.02),
                nn.LayerNorm(200),
                nn.Linear(200, 1),
                nn.Sigmoid()
            )
    
    # 생성기
    
    self.model = nn.Sequential(
                nn.Linear(1, 200),
                nn.LeakyReLU(0.02),
                nn.LayerNorm(200),
                nn.Linear(200, 784),
                nn.Sigmoid()
            )
  • もAdamオプティカル(光学式)ドライブを使用しています.
  • # 판별기 and 생성기
    
    self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
    再生成された画像を表示します.

    さっきよりもイメージの違いが目立つけど.
    同様に,全体的に生成されるパターンにも同様の問題がある.
    入力seed値を増やすことで改善できます.
    既存のランダムな数値を入力することから始めます.
    100個に増えれば.
    # 생성기
    
    self.model = nn.Sequential(
                nn.Linear(100, 200),
                nn.LeakyReLU(0.02),
                nn.LayerNorm(200),
                nn.Linear(200, 784),
                nn.Sigmoid()
            )

    いくつかの変化が起こったが,モードクラッシュは解決しなかった.
    判別器とジェネレータに与えるシード値の性質は異なるはずだという.
    判別器については、実際のデータで観察される値を0から1までの値にする.
    ジェネレータでは、0から1の間の値を必要とせず、ニューラルネットワークの学習では、平均値が0で分散が限られ、正規化された値が学習に有利である.
    すなわち,ジェネレータでは,標準正規分布,平均値0,分散値1の分布から値を抽出することが有意義である.
    torch.rand=0と1の平均値
    torch.randn=平均値0,標準偏差1のGauss分布から数値を生成
    ここまで続けると、モードクラッシュは解決されます.
    シード値は、異なるタイプの数値形式を生成します.
    フォークが増えるにつれて、品質が向上します.

    モードクラッシュは、常に問題ドメインによって解決されるわけではありません.
    上級者では解決できないことも多い.
    モードクラッシュを解決するために
    私はいろいろな措置を取ることを学ばなければならない.

    GANとSEED


    しかしseedが生成した画像にはいくつかの特性がある.
    seed Aとseed Bがある場合、
    seed Aからseed Bまでの値を12ステップに分けて画像を作成します.
    seed Aが生成した画像から、seed Bが生成した画像に徐々に変化していくことがわかる.

    seed Aとseed Bを加えて、
    生成された画像はseed Aの画像とseed Bの画像と同じである.

    seed Aからseed Bを減算すると、
    単にseed Aの画像からseed Bと重なる部分を削除するわけではない.


    上記画像の結果、seedが生成した画像
    想像以上に複雑な論理に戻ることを示唆します.
    (seed 1-seed 2の数字3)