Part 1 covered the ELBO, KL divergence, and the reparameterization trick. Here we turn those equations into a working VAE in pure PyTorch---just torch.nn, no external generative modeling libraries.
The Encoder
The encoder maps a flattened 28$\times$28 MNIST image (784 dimensions) to the parameters of a Gaussian distribution over the latent space:
self.enc_fc1 = nn.Linear(784, 512)
self.enc_fc2 = nn.Linear(512, 256)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_log_var = nn.Linear(256, latent_dim)
Two hidden layers with ReLU activations form the shared backbone. The final layer splits into two parallel heads: one for $\boldsymbol{\mu}$ and one for $\log \boldsymbol{\sigma}^2$. This dual-head design is the key architectural difference from a standard autoencoder.
Why $\log \boldsymbol{\sigma}^2$ instead of $\boldsymbol{\sigma}$? The log-variance can take any real value, while $\boldsymbol{\sigma}$ must be positive. Parameterizing in log-space avoids the need for explicit positivity constraints and improves numerical stability.
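To make the data flow concrete, here is a minimal sketch of how the four layers above might be wired into a forward pass. The class name Encoder, the default latent_dim=2, and the random input batch are assumptions made for illustration, not necessarily how the full model is organized.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Hypothetical wrapper around the four encoder layers shown above."""

    def __init__(self, latent_dim=2):
        super().__init__()
        self.enc_fc1 = nn.Linear(784, 512)
        self.enc_fc2 = nn.Linear(512, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        h = x.view(-1, 784)            # flatten (B, 1, 28, 28) -> (B, 784)
        h = F.relu(self.enc_fc1(h))    # shared backbone, layer 1
        h = F.relu(self.enc_fc2(h))    # shared backbone, layer 2
        # Two parallel heads read the same 256-dimensional features.
        return self.fc_mu(h), self.fc_log_var(h)


encoder = Encoder(latent_dim=2)
x = torch.rand(8, 1, 28, 28)           # random stand-in for an MNIST batch
mu, log_var = encoder(x)

# log_var is unconstrained, yet the standard deviation derived from it
# is positive for any real value the head produces.
std = torch.exp(0.5 * log_var)
print(mu.shape, log_var.shape, bool((std > 0).all()))
```

Neither head has an activation function: both outputs are raw linear projections, which is what lets the log-variance range over all real numbers.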
The Reparameterization Trick
The core of the whole model, in three lines:
def reparameterize(self, mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    z = mu + eps * std
    return z
torch.exp(0.5 * log_var) converts log-variance to standard deviation. torch.randn_like(std) samples $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ with matching shape and device. The final line computes $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\epsilon} \odot \boldsymbol{\sigma}$.
During training, this introduces stochasticity---each forward pass samples a different $\mathbf{z}$ for the same input. At inference time, you can either sample or simply use $\boldsymbol{\mu}$ as the latent representation.
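The point of the trick is that the sample stays differentiable with respect to the encoder outputs. The following sketch checks this on toy tensors; the standalone function (without self), the tensor shapes, and the variable names are assumptions chosen for the demonstration.

```python
import torch


def reparameterize(mu, log_var):
    # Standalone version of the method above, for illustration only.
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std


torch.manual_seed(0)
mu = torch.zeros(4, 2, requires_grad=True)
log_var = torch.zeros(4, 2, requires_grad=True)

# Two calls with identical inputs draw different noise.
z1 = reparameterize(mu, log_var)
z2 = reparameterize(mu, log_var)
samples_differ = not torch.equal(z1, z2)

# Gradients reach both parameters because the randomness lives in eps:
# dz/dmu = 1 and dz/dlog_var = 0.5 * eps * std.
z1.sum().backward()
mu_grad_is_one = torch.allclose(mu.grad, torch.ones_like(mu))
log_var_has_grad = log_var.grad is not None

# Deterministic alternative at inference time: use the mean directly.
z_eval = mu.detach()
print(samples_differ, mu_grad_is_one, log_var_has_grad)
```

Sampling z directly from a distribution object without this rewrite would block the gradient to mu and log_var, which is why the noise is drawn separately and then scaled and shifted.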
The Decoder
The decoder mirrors the encoder, mapping from latent space back to pixel space:
self.dec_fc1 = nn.Linear(latent_dim, 256)
self.dec_fc2 = nn.Linear(256, 512)
self.dec_fc3 = nn.Linear(512, 784)
Hidden layers use ReLU activations; the output layer uses a sigmoid to bound reconstructions to $[0, 1]$, matching the normalized pixel range of MNIST.
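Putting the pieces together, a fully connected VAE could look roughly like the sketch below. The layer names follow the snippets in this post; the class name VAE and the method names encode and decode are assumptions, and the input is a random tensor standing in for MNIST.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """Sketch of the FC VAE; encode/decode method names are assumed."""

    def __init__(self, latent_dim=2):
        super().__init__()
        # Encoder: 784 -> 512 -> 256 -> (mu, log_var)
        self.enc_fc1 = nn.Linear(784, 512)
        self.enc_fc2 = nn.Linear(512, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)
        # Decoder: latent_dim -> 256 -> 512 -> 784
        self.dec_fc1 = nn.Linear(latent_dim, 256)
        self.dec_fc2 = nn.Linear(256, 512)
        self.dec_fc3 = nn.Linear(512, 784)

    def encode(self, x):
        h = F.relu(self.enc_fc1(x.view(-1, 784)))
        h = F.relu(self.enc_fc2(h))
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.dec_fc1(z))
        h = F.relu(self.dec_fc2(h))
        return torch.sigmoid(self.dec_fc3(h))   # pixel values in [0, 1]

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


model = VAE(latent_dim=20)
x = torch.rand(8, 1, 28, 28)            # random stand-in for an MNIST batch
recon_x, mu, log_var = model(x)
print(recon_x.shape, mu.shape, log_var.shape)
```

Returning mu and log_var alongside the reconstruction keeps everything the loss needs available from a single forward call.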
The ELBO Loss Function
The loss is the negative ELBO, which decomposes into two terms:
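import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, log_var):
    # Reconstruction term: BCE summed over all pixels and all batch elements
    recon_loss = F.binary_cross_entropy(
        recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian
    kl_loss = -0.5 * torch.sum(
        1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss, recon_loss, kl_loss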
The reconstruction loss (BCE) measures how well the decoder reproduces the input. Using reduction='sum' sums over all 784 pixels and all batch elements, giving a total loss that scales with both image size and batch size.
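The KL term implements the closed-form divergence between the diagonal Gaussian posterior and the standard normal prior:
$$D_{\mathrm{KL}}\big(q(\mathbf{z} \mid \mathbf{x}) \,\|\, \mathcal{N}(\mathbf{0}, \mathbf{I})\big) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$
Here $d$ is the latent dimension; the code additionally sums over the batch.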
We return both terms separately to monitor them during training---the balance between reconstruction and KL is critical for understanding VAE behavior.
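As a sanity check on the KL term in vae_loss, one can compare it against PyTorch's own analytic KL between two Gaussians. This is a small sketch on random tensors; the shapes and variable names are assumptions, and torch.distributions is used here only as a reference, not as part of the model.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 20)        # pretend encoder means
log_var = torch.randn(8, 20)   # pretend encoder log-variances

# Expression used in vae_loss, summed over latent dims and batch.
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Reference value: elementwise KL(N(mu, sigma^2) || N(0, 1)), then summed.
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum()

# The penalty vanishes when the posterior equals the prior
# (mu = 0 and log_var = 0, i.e. sigma = 1).
zeros = torch.zeros(8, 20)
kl_at_prior = -0.5 * torch.sum(1 + zeros - zeros.pow(2) - zeros.exp())

print(kl_closed.item(), kl_reference.item(), kl_at_prior.item())
```

The two values should agree up to floating-point error, and the zero at the prior shows what the KL term pulls the encoder toward when the reconstruction term does not push back.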
The Convolutional Variant
For better spatial feature capture, we also implement a ConvVAE:
Encoder: Two Conv2d layers with stride 2 for downsampling (1$\times$28$\times$28 → 32$\times$14$\times$14 → 64$\times$7$\times$7), followed by a fully connected layer to the latent parameters.
Decoder: Fully connected layer to 64$\times$7$\times$7, then two ConvTranspose2d layers with stride 2 for upsampling back to 1$\times$28$\times$28.
The convolutional variant preserves spatial structure that the fully connected model discards when flattening.
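The description above fixes the channel counts and spatial sizes but not the kernel size or padding. The sketch below assumes kernel_size=4 and padding=1 for every layer, which is one choice that yields exactly 28 → 14 → 7 and back; the layer names and the use of two separate linear heads are also assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvVAE(nn.Module):
    """Sketch only: kernel_size=4 and padding=1 are assumed values."""

    def __init__(self, latent_dim=20):
        super().__init__()
        # Encoder: 1x28x28 -> 32x14x14 -> 64x7x7
        self.enc_conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_log_var = nn.Linear(64 * 7 * 7, latent_dim)
        # Decoder: latent -> 64x7x7 -> 32x14x14 -> 1x28x28
        self.dec_fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.dec_deconv1 = nn.ConvTranspose2d(
            64, 32, kernel_size=4, stride=2, padding=1)
        self.dec_deconv2 = nn.ConvTranspose2d(
            32, 1, kernel_size=4, stride=2, padding=1)

    def encode(self, x):
        h = F.relu(self.enc_conv1(x))
        h = F.relu(self.enc_conv2(h))
        h = h.flatten(start_dim=1)          # (B, 64, 7, 7) -> (B, 3136)
        return self.fc_mu(h), self.fc_log_var(h)

    def decode(self, z):
        h = F.relu(self.dec_fc(z)).view(-1, 64, 7, 7)
        h = F.relu(self.dec_deconv1(h))
        return torch.sigmoid(self.dec_deconv2(h))

    def forward(self, x):
        mu, log_var = self.encode(x)
        std = torch.exp(0.5 * log_var)
        z = mu + torch.randn_like(std) * std
        return self.decode(z), mu, log_var


model = ConvVAE(latent_dim=20)
x = torch.rand(8, 1, 28, 28)                # random stand-in for MNIST
recon_x, mu, log_var = model(x)
mid_shape = model.enc_conv1(x).shape        # after the first stride-2 conv
print(mid_shape, recon_x.shape)
```

Because the output keeps its 1$\times$28$\times$28 shape, a loss that flattens with view(-1, 784) can be reused without changes.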
Parameter Counts
With latent_dim=2: 1,068,820 trainable parameters.
With latent_dim=20: 1,082,680 trainable parameters.
The larger latent space adds only 13,860 parameters (770 per extra latent dimension), since the latent layers are tiny relative to the hidden layers. The effect on reconstruction quality, however, is substantial (see Part 3).
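These totals can be reproduced by hand from the layer shapes of the fully connected model, counting weights plus biases for each linear layer. The helper names below are hypothetical; the layer sizes are the ones listed in the encoder and decoder sections.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out


def fc_vae_params(latent_dim):
    layers = [
        (784, 512), (512, 256),   # encoder backbone
        (256, latent_dim),        # fc_mu
        (256, latent_dim),        # fc_log_var
        (latent_dim, 256),        # dec_fc1
        (256, 512), (512, 784),   # remaining decoder layers
    ]
    return sum(linear_params(n_in, n_out) for n_in, n_out in layers)


small = fc_vae_params(2)
large = fc_vae_params(20)
print(small, large, large - small)   # 1068820 1082680 13860
```

Each extra latent dimension costs 257 parameters in each of the two heads and 256 weights in dec_fc1, which is 770 in total, and 18 extra dimensions give 18 × 770 = 13,860.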
Training Setup
- Dataset: MNIST, 5,000-image training subset
- Optimizer: Adam with learning rate $10^{-3}$
- Batch size: 64
- Epochs: 30
- Latent dimensions: 2 (visualization) and 20 (quality)
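A training loop matching this setup might look like the sketch below. It is not the actual training script: TinyVAE is a compact stand-in with the same layer sizes so the block runs on its own, the train function and its history return value are assumptions, and the data is random noise rather than the MNIST subset. The demo call also uses 2 epochs instead of 30 to stay quick.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class TinyVAE(nn.Module):
    """Compact stand-in for the FC VAE (hypothetical name)."""

    def __init__(self, latent_dim=2):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(784, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU())
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_log_var = nn.Linear(256, latent_dim)
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Sigmoid())

    def forward(self, x):
        h = self.enc(x.view(-1, 784))
        mu, log_var = self.fc_mu(h), self.fc_log_var(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)
        return self.dec(z), mu, log_var


def vae_loss(recon_x, x, mu, log_var):
    recon = F.binary_cross_entropy(
        recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + kl, recon, kl


def train(model, loader, epochs=30, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    model.train()
    for _ in range(epochs):
        total = 0.0
        for (x,) in loader:
            optimizer.zero_grad()
            recon_x, mu, log_var = model(x)
            loss, _, _ = vae_loss(recon_x, x, mu, log_var)
            loss.backward()
            optimizer.step()
            total += loss.item()
        # The loss is summed, so divide by dataset size for a per-image mean.
        history.append(total / len(loader.dataset))
    return history


torch.manual_seed(0)
fake_images = torch.rand(256, 1, 28, 28)   # random stand-in, not MNIST
loader = DataLoader(TensorDataset(fake_images), batch_size=64, shuffle=True)
history = train(TinyVAE(latent_dim=2), loader, epochs=2)
print(history)
```

Dividing the accumulated loss by the number of images makes the logged value comparable across batch sizes, which matters because reduction='sum' scales the raw loss with the batch.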
Conclusion
The model and loss function together fit in under 150 lines. Part 3 trains both the FC and convolutional variants and explores their latent spaces through visualization, interpolation, and generation.