へんぶんじどうエンコーダ

2974 ワード

へんぶんじどうエンコーダ


変分エンコーダは自動エンコーダのアップグレードバージョンであり、その構造は自動エンコーダと類似しており、エンコーダとデコーダからも構成されている.
思い出してみると、自動エンコーダには任意に画像を生成することができないという問題があります.私たちは自分で隠しベクトルを構築することができないので、1枚の画像入力符号化を通じて得られた隠しベクトルが何なのかを知る必要があります.このとき、私たちは自動エンコーダを変分することでこの問題を解決することができます.
実際の原理は特に簡単で、符号化プロセスにいくつかの制限を加えるだけで、生成された暗黙ベクトルは、一般的な自動エンコーダと最大の違いである標準的な正規分布にざっと従うことができる.
これにより、新しいピクチャを生成するのは簡単です.標準的な正規分布のランダムな隠しベクトルを与えるだけで、デコーダによって元のピクチャを先に符号化する必要がなく、私たちが望んでいるピクチャを生成することができます.
一般に,encoderによって得られた暗黙ベクトルは標準的な正規分布ではなく,両分布の類似度を測定するためにKL divergenceを用いて暗黙ベクトルと標準正規分布の違いのlossを表し,もう1つのlossは生成ピクチャと元のピクチャの平均二乗誤差を用いて表した.
KL divergenceの公式は以下の通りです.

じゅうパラメータ


KL divergenceでの積分を計算することを避けるために,重パラメータの技法を用いて,1つの暗黙ベクトルを生成するたびに,2つのベクトル,1つの平均値,1つの標準差を生成するのではなく,ここでは符号化後の暗黙ベクトルが1つの正規分布に従うことをデフォルトとし,1つの標準正規分布に標準差を乗じて平均値を加算してこの正規分布を合成することができる.最後にlossは,この生成された正規分布が標準正規分布,すなわち平均値0,分散1に適合することを望んでいる.
標準的な変分自動エンコーダは以下の通りです
mu(平均)とlogvar(分散)を正規分布に訓練する必要がある.すなわち,平均値を0に近づけ,分散を1に近づける必要がある.復号後のピクチャと元のピクチャのlossを低減する必要があるので,我々の最終lossは平均分散と正規分布のlossと復号後と符号化前のlossである.
reconstruction_function = nn.MSELoss(size_average=False) #MSE 

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images # 
    x: origin images # 
    mu: latent mean # 
    logvar: latent log variance # 
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD // loss loss

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# tensor 
def to_img(x):
    '''
     
    '''
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x

変分デコーダのモデルは以下の通りである.
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20) # mean 
        self.fc22 = nn.Linear(400, 20) # var 
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if torch.cuda.is_available():
            eps = Variable(eps.cuda())
        else:
            eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.tanh(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x) #  
        z = self.reparametrize(mu, logvar) #  
        return self.decode(z), mu, logvar #  , 

 
元のリンク:https://github.com/L1aoXingyu/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_GAN/vae.ipynb