Deconstructing VAEs from Scratch

Part 2: PyTorch Implementation

Introduction

In Part 1, we derived the ELBO, KL divergence, and the reparameterization trick. Now we translate every equation into working PyTorch code, building a complete VAE from scratch---no generative modeling libraries, just torch.nn.

The Encoder

The encoder maps a flattened 28$\times$28 MNIST image (784 dimensions) to the parameters of a Gaussian distribution over the latent space:

self.enc_fc1 = nn.Linear(784, 512)
self.enc_fc2 = nn.Linear(512, 256)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_log_var = nn.Linear(256, latent_dim)

Two hidden layers with ReLU activations form the shared backbone. The final layer splits into two parallel heads: one for $\boldsymbol{\mu}$ and one for $\log \boldsymbol{\sigma}^2$. This dual-head design is the key architectural difference from a standard autoencoder.

Why $\log \boldsymbol{\sigma}^2$ instead of $\boldsymbol{\sigma}$? The log-variance can take any real value, while $\boldsymbol{\sigma}$ must be positive. Parameterizing in log-space avoids the need for explicit positivity constraints and improves numerical stability.
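Putting the layers together, the encoder forward pass might look like the sketch below (the class name `VAEEncoder` and the standalone-module packaging are illustrative; in the full model these layers live alongside the decoder in one `nn.Module`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAEEncoder(nn.Module):
    """Illustrative wrapper around the encoder layers described above."""
    def __init__(self, latent_dim=20):
        super().__init__()
        self.enc_fc1 = nn.Linear(784, 512)
        self.enc_fc2 = nn.Linear(512, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        # Shared backbone with ReLU activations
        h = F.relu(self.enc_fc1(x))
        h = F.relu(self.enc_fc2(h))
        # Two parallel heads: mean and log-variance of q(z|x)
        return self.fc_mu(h), self.fc_log_var(h)
```

Both heads read from the same 256-dimensional representation; only the final projection differs.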

The Reparameterization Trick

The most elegant function in the entire model:

def reparameterize(self, mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    z = mu + eps * std
    return z

Three lines. torch.exp(0.5 * log_var) converts log-variance to standard deviation. torch.randn_like(std) samples $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ with the same shape and device. The final line computes $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\epsilon} \odot \boldsymbol{\sigma}$.

During training, this introduces stochasticity---each forward pass samples a different $\mathbf{z}$ for the same input. At inference time, you can either sample or simply use $\boldsymbol{\mu}$ as the latent representation.
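The point of the trick is that gradients flow through $\boldsymbol{\mu}$ and $\log \boldsymbol{\sigma}^2$ even though $\mathbf{z}$ is random. A quick check (using the same three lines as a free function for demonstration):

```python
import torch

def reparameterize(mu, log_var):
    # Same logic as the method above
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(5, requires_grad=True)
log_var = torch.zeros(5, requires_grad=True)
z = reparameterize(mu, log_var)
z.sum().backward()
# mu.grad is all ones: dz/dmu = 1, independent of the noise sample.
# log_var.grad depends on eps, but it exists -- the sampling step
# did not cut the computation graph.
```

Had we sampled directly from $\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\sigma}^2)$, `backward()` would have nothing to propagate through the sampling node.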

The Decoder

The decoder mirrors the encoder, mapping from latent space back to pixel space:

self.dec_fc1 = nn.Linear(latent_dim, 256)
self.dec_fc2 = nn.Linear(256, 512)
self.dec_fc3 = nn.Linear(512, 784)

ReLU activations are applied to the hidden layers, and a sigmoid on the output bounds reconstructions to $[0, 1]$, matching the normalized pixel range of MNIST.
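As with the encoder, a minimal sketch of the decoder forward pass (the class name `VAEDecoder` is illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAEDecoder(nn.Module):
    """Illustrative wrapper around the decoder layers described above."""
    def __init__(self, latent_dim=20):
        super().__init__()
        self.dec_fc1 = nn.Linear(latent_dim, 256)
        self.dec_fc2 = nn.Linear(256, 512)
        self.dec_fc3 = nn.Linear(512, 784)

    def forward(self, z):
        h = F.relu(self.dec_fc1(z))
        h = F.relu(self.dec_fc2(h))
        # Sigmoid bounds each pixel to [0, 1]
        return torch.sigmoid(self.dec_fc3(h))
```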

The ELBO Loss Function

The loss decomposes into two terms:

def vae_loss(recon_x, x, mu, log_var):
    recon_loss = F.binary_cross_entropy(
        recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    kl_loss = -0.5 * torch.sum(
        1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss, recon_loss, kl_loss

The reconstruction loss (BCE) measures how well the decoder reproduces the input. Using reduction='sum' sums over all 784 pixels and all batch elements, giving a total loss that scales with both image size and batch size.

The KL divergence implements the closed-form expression:

$$ D_{\text{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right) $$

We return both terms separately to monitor them during training---the balance between reconstruction and KL is critical for understanding VAE behavior.
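A useful sanity check on the KL term: when the approximate posterior equals the prior $\mathcal{N}(\mathbf{0}, \mathbf{I})$, i.e. $\boldsymbol{\mu} = \mathbf{0}$ and $\log \boldsymbol{\sigma}^2 = \mathbf{0}$, the closed form should evaluate to exactly zero:

```python
import torch

mu = torch.zeros(8, 20)
log_var = torch.zeros(8, 20)
# Each term inside the sum is 1 + 0 - 0 - exp(0) = 0
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
```

Any deviation of the posterior from the prior, in mean or in variance, makes the KL strictly positive.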

The Convolutional Variant

For better spatial feature capture, we also implement a ConvVAE:

Encoder: Two Conv2d layers with stride 2 for downsampling (1$\times$28$\times$28 → 32$\times$14$\times$14 → 64$\times$7$\times$7), followed by a fully connected layer to the latent parameters.

Decoder: Fully connected layer to 64$\times$7$\times$7, then two ConvTranspose2d layers with stride 2 for upsampling back to 1$\times$28$\times$28.

The convolutional variant preserves spatial structure that the fully connected model discards when flattening.
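The shape bookkeeping can be verified with a quick sketch. The text specifies only the strides and channel counts, so the kernel size (3), padding (1), and `output_padding` (1) below are assumptions chosen to reproduce the stated 28 → 14 → 7 → 14 → 28 progression:

```python
import torch
import torch.nn as nn

enc = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # 1x28x28 -> 32x14x14
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 32x14x14 -> 64x7x7
    nn.ReLU(),
)
dec = nn.Sequential(
    nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2,
                       padding=1, output_padding=1),        # 64x7x7 -> 32x14x14
    nn.ReLU(),
    nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2,
                       padding=1, output_padding=1),        # 32x14x14 -> 1x28x28
    nn.Sigmoid(),
)
x = torch.randn(4, 1, 28, 28)
h = enc(x)
```

`output_padding=1` is needed because stride-2 transposed convolutions are otherwise ambiguous about whether a 7-wide input maps back to 13 or 14.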

Parameter Counts

With latent_dim=2: 1,068,820 trainable parameters. With latent_dim=20: 1,082,680 trainable parameters.

The difference is minimal---only 13,860 additional parameters---because the latent layer is tiny compared to the hidden layers. Yet the impact on reconstruction quality is dramatic, as we will see in Part 3.
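These counts follow directly from the seven linear layers described in this part, and can be reproduced by summing `numel()` over their weights and biases:

```python
import torch.nn as nn

def count_vae_params(latent_dim):
    layers = nn.ModuleList([
        nn.Linear(784, 512), nn.Linear(512, 256),  # encoder backbone
        nn.Linear(256, latent_dim),                # mu head
        nn.Linear(256, latent_dim),                # log-variance head
        nn.Linear(latent_dim, 256),                # decoder input
        nn.Linear(256, 512), nn.Linear(512, 784),  # decoder
    ])
    return sum(p.numel() for p in layers.parameters())

count_vae_params(2)   # 1,068,820
count_vae_params(20)  # 1,082,680
```

Each extra latent unit costs 257 parameters per encoder head plus 256 decoder input weights, i.e. 770 parameters per unit, which gives the 13,860 gap between the two configurations.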

Training Setup
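A minimal training sketch, assembled from the pieces above. The hyperparameters (Adam, learning rate 1e-3, batch size 128) are typical choices assumed here rather than specified in the text, and a random batch stands in for MNIST pixels:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super().__init__()
        self.enc_fc1 = nn.Linear(784, 512)
        self.enc_fc2 = nn.Linear(512, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)
        self.dec_fc1 = nn.Linear(latent_dim, 256)
        self.dec_fc2 = nn.Linear(256, 512)
        self.dec_fc3 = nn.Linear(512, 784)

    def forward(self, x):
        h = F.relu(self.enc_fc2(F.relu(self.enc_fc1(x))))
        mu, log_var = self.fc_mu(h), self.fc_log_var(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)
        h = F.relu(self.dec_fc2(F.relu(self.dec_fc1(z))))
        return torch.sigmoid(self.dec_fc3(h)), mu, log_var

def vae_loss(recon_x, x, mu, log_var):
    recon = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + kl

model = VAE(latent_dim=2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed hyperparameters
x = torch.rand(128, 784)  # stand-in batch; real training iterates over MNIST
for _ in range(3):
    opt.zero_grad()
    recon, mu, log_var = model(x)
    loss = vae_loss(recon, x, mu, log_var)
    loss.backward()
    opt.step()
```

In the real training loop, `x` comes from a `DataLoader` over MNIST (flattened and normalized to $[0, 1]$), and both loss terms are logged per epoch, as noted above.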

Conclusion

The complete VAE implementation requires surprisingly little code---under 150 lines for the model and loss function. The architectural simplicity belies the mathematical depth. In Part 3, we train these models and explore their latent spaces through visualization, interpolation, and random generation.