貧乏人のVAE解説:確率論の幽霊から解放


前書き

AE = Autoencoder
VAE = Variational Autoencoder

この記事は、読者がAEの知識を持っていることを前提としています。 直感的には、AEをコンプレッサーとして理解することができます。実際、私が知る限り、GoogleはAEを使用してファイルを圧縮しています。これは非常に効率的です。

AEの欠点は、次の図に示すように、さまざまな入力が潜在空間の離散ポイントにマッピングされ、ポイント間に接続がないため、潜在空間の広い領域が無駄になることです。

VAEの導入の理由は、ポイント間の真空を埋めることです。 この目標を達成するために必要なのは、以下で1つずつ分析される単純な2ステップの操作だけです

離散点の範囲を拡大する

ポイント間の真空を埋める方法は、非常に簡単です、離散ポイントの範囲を拡張するだけです。 では、ポイントの範囲を拡大するにはどうすればよいでしょうか。 答えは、以下に示すように、乱数を導入することです。

神が私たちに与えた乱数は、拳銃をショットガンに変えました。

VAEの説明については、「Intuitively Understanding Variational Autoencoders」というブログを強くお勧めします。 グーグルで検索すれば見つけます。トリックを見つける前に、私は多くのブログと論文を見て、長い間調べました。この記事は間違いなく最高の記事です(Intuitiveというキーワードのつけた記事は常に良いです)。

しかし、ほとんどすべての人が元の論文の手順に従って乱数を導入しました。meanとvarianを使用して確率分布をシミュレートしました。 これは、VAEを理解するための最大のしきい値です。

実際、乱数を使用するのは非常に簡単です。乱数を直接使用するだけで、確率分布を人為的にシミュレートする必要は全然ありません。乱数自体は正規分布にしたがっています。 サンプルコードは次のとおりで、効果は上図と完全に一致しています。

def add_disturbance(self, z):
  epsilon = torch.randn(z.shape)
  return z + epsilon

ただし、確率変数を追加した後でも、潜在的なスペースは無限であり、制限をせずにポイントが一箇所に集める理由がないため、ポイントは分離しています。次の最後のステップを導入する必要があります。

離散点を集める

上の図に示すように、左の結果を取得したいですが、制限がない場合、ステップ1を従って、右の図の結果しか取得できません。

では、どうすれば離散点を集められるですか? 非常に簡単です。アンカーポイントを導入するだけです。 このアンカーポイントは、torch.randnを使用して直接取得できる潜在空間内の任意のポイントにすることができます。または、座標原点を直接使用することで、計算も簡単にできます。 次に、各ポイントからアンカーポイントまでの距離を計算し、この距離をLossとして定義して、トレーニングでLossを減らすことができます。 実際、KL Lossもまさにこの事を行いています。

KL = z.pow(2).sum().sqrt() # Force z decay to zero

KL Lossだけの場合、すべてのポイントが原点に落ちるだけで終わりです。これは私たちが望んでいることではありませんが、再構築Lossがあります(忘れないでください、これがAEの主な目標です) 、2つの損失を足し合わせるだけで、最適化はバランスに達します。

これはVAEです。

再構築Lossについて話しましょう。元の論文の再構築Lossも非常に理解し難いので、最後の仕上げとして再構築Lossを単純化します。

再構築Lossを単純化

# Pre define
self.MSE = nn.MSELoss(reduction='sum') # Sum is important
# ...
# x from input
z = self.get_z(x)
y = self.get_y(z)
reconstruction = self.MSE(y, x)
# sum up loss
loss = reconstruction + KL

これはVAEです。

以下は結果の概略図です。 最初の行は元の画像、2番目の行は再構成された画像、3番目の行は最初の番号から2番目の番号への遷移画像です。

VAEを作るために、確率論はまったく必要ないことです。確率論が少ないほど、私たちはより良く生きることができます。 私はこの実装を貧乏人のVAEと名付けました。この実装を論文で使用している人がいたら、少なくとも私に知らせてください、ありがとうございます。

コード

コードはgithubでホストされています:https://github.com/zhuobinggang/poor-man-VAE

主要部分は次のとおりです。

class My_VAE_V2(nn.Module):
    def __init__(self, z_dim):
      super().__init__()
      self.fw1 = nn.Linear(28*28, 28*14)
      self.z_layer = nn.Linear(28*14, z_dim)
      self.fw2 = nn.Linear(z_dim, 28*14)
      self.recover_layer = nn.Linear(28*14, 28*28)
      self.MSE = nn.MSELoss(reduction='sum')

    def _encoder(self, x):
      z = self.add_disturbance(self.z_layer(F.sigmoid(self.fw1(x))))
      return z

    def add_disturbance(self, z):
      epsilon = torch.randn(z.shape)
      return z + epsilon

    def _decoder(self, z):
      o = F.sigmoid(self.recover_layer(F.sigmoid(self.fw2(z))))
      return o

    def forward(self, x):
      z = self._encoder(x)
      o = self._decoder(z)
      return o, z

    def loss(self, x):
      z = self._encoder(x)
      KL = z.pow(2).sum().sqrt() # Force z decay to zero
      y = self._decoder(z)
      reconstruction = self.MSE(y, x)
      return KL + reconstruction