Deconstructing CapsNets from Scratch

Part 2: PyTorch Implementation

Introduction

This post translates the CapsNet math from Part 1 into PyTorch. We implement every component from scratch: squashing, PrimaryCapsule, DigitCapsule with dynamic routing, margin loss, and the reconstruction decoder.

The paper glosses over several implementation details -- tensor shapes inside routing, gradient detachment, loss weighting -- so we walk through those here.

Architecture Overview

The CapsNet for MNIST has four stages:

  1. Conv1: Conv2d(1, 256, kernel_size=9, stride=1) + ReLU. $(B, 1, 28, 28) \to (B, 256, 20, 20)$
  2. PrimaryCapsule: Conv2d(256, 256, kernel_size=9, stride=2) + Reshape + Squash. Output: $(B, 1152, 8)$ -- 1,152 capsules of dimension 8
  3. DigitCapsule: Dynamic routing with learned weight matrices. Output: $(B, 10, 16)$ -- 10 class capsules of dimension 16
  4. Decoder: Linear(16 -> 512 -> 1024 -> 784) with masking. Reconstructs the input from the winning capsule

Total parameters: 8,141,840.
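The feature-map sizes and the parameter total can be reproduced with plain arithmetic. The sketch below assumes a bias on every convolutional and linear layer, no bias on the routing weights, and unpadded convolutions; the helper names `conv_out`, `conv_params`, and `linear_params` are illustrative and not part of the model code.

```python
def conv_out(size, kernel, stride):
    # Output width/height of an unpadded convolution
    return (size - kernel) // stride + 1

def conv_params(c_in, c_out, kernel):
    # Weights plus one bias per output channel
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    # Weights plus one bias per output unit
    return n_in * n_out + n_out

conv1_size = conv_out(28, 9, 1)             # 20
primary_size = conv_out(conv1_size, 9, 2)   # 6
num_primary_caps = 32 * primary_size ** 2   # 32 capsule types x 6 x 6 = 1152

counts = {
    "conv1": conv_params(1, 256, 9),
    "primary_caps": conv_params(256, 256, 9),
    "digit_caps_W": 1152 * 10 * 16 * 8,     # routing weights, no bias
    "decoder": (linear_params(16, 512)
                + linear_params(512, 1024)
                + linear_params(1024, 784)),
}
total = sum(counts.values())
print(counts, total)
```

Most of the parameters sit in the PrimaryCapsule convolution, not in the routing weights.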

The Squashing Function

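import torch
import torch.nn as nn
import torch.nn.functional as F

def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    safe_norm = torch.sqrt(squared_norm + 1e-8)   # eps keeps the norm non-zero
    scale = squared_norm / (1.0 + squared_norm)   # always in [0, 1)
    return scale * s / safe_norm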

The 1e-8 inside the square root guards against division by zero when a capsule's activation is exactly zero, and keeps the gradient of the square root finite there. Without it, both the output and the gradients become NaN.

Mathematically:

$$ \mathbf{v} = \underbrace{\frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2}}_{\text{scale factor}} \cdot \underbrace{\frac{\mathbf{s}}{\|\mathbf{s}\|}}_{\text{unit vector}} $$

keepdim=True is essential so the scaling factor broadcasts correctly against the input tensor.
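A quick numerical check makes the behaviour concrete. This is a minimal sketch with hand-picked input vectors (an assumption chosen for illustration): short vectors shrink toward zero, long vectors approach unit length, and the all-zero vector still yields finite gradients.

```python
import torch

def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    safe_norm = torch.sqrt(squared_norm + 1e-8)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / safe_norm

# Four 2D "capsules" with lengths 0, 0.1, 5, and 500
s = torch.tensor([[0.0, 0.0],
                  [0.1, 0.0],
                  [3.0, 4.0],
                  [300.0, 400.0]], requires_grad=True)

v = squash(s)
lengths = v.detach().norm(dim=-1)
print(lengths)   # length 5 maps to 25/26, length 500 maps to just under 1

# Gradients stay finite even for the all-zero capsule
v.sum().backward()
grad_is_finite = torch.isfinite(s.grad).all().item()
print(grad_is_finite)
```

The direction of each vector is unchanged; only its length is rescaled to $\|\mathbf{s}\|^2 / (1 + \|\mathbf{s}\|^2)$.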

PrimaryCapsule Layer

Converts convolutional features into capsule vectors.

class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels=256, num_types=32,
                 caps_dim=8, kernel_size=9, stride=2):
        super().__init__()
        self.num_types = num_types
        self.caps_dim = caps_dim
        self.conv = nn.Conv2d(in_channels,
                              num_types * caps_dim,
                              kernel_size, stride)

    def forward(self, x):
        out = self.conv(x)       # (B, 256, 6, 6)
        B = out.size(0)
        # Split channels into (capsule type, capsule dimension)
        out = out.view(B, self.num_types, self.caps_dim, -1)  # (B, 32, 8, 36)
        out = out.permute(0, 1, 3, 2)                # (B, 32, 36, 8)
        out = out.reshape(B, -1, self.caps_dim)      # (B, 1152, 8)
        return squash(out)

Shape Walkthrough
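The shape comments in forward can be checked step by step. The sketch below is an illustration, not part of the model: it substitutes a dummy tensor for the convolution output (an assumption) and applies the same view, permute, and reshape calls. It also confirms which values end up in a given capsule.

```python
import torch

B = 2
# Stand-in for self.conv(x): 256 channels on a 6x6 grid, filled with distinct values
out = torch.arange(B * 256 * 6 * 6, dtype=torch.float32).view(B, 256, 6, 6)

step1 = out.view(B, 32, 8, -1)       # (B, 32, 8, 36): channels -> (type, dim), grid flattened
step2 = step1.permute(0, 1, 3, 2)    # (B, 32, 36, 8): capsule dimension moved last
caps = step2.reshape(B, -1, 8)       # (B, 1152, 8): types and positions merged

shapes = [tuple(t.shape) for t in (out, step1, step2, caps)]
print(shapes)

# Capsule index = type * 36 + row * 6 + col. Its 8 components are the
# 8 consecutive channels of that type at that grid position.
t, row, col = 5, 2, 3
expected = out[0, t * 8:(t + 1) * 8, row, col]
layout_ok = torch.equal(caps[0, t * 36 + row * 6 + col], expected)
print(layout_ok)
```

The permute matters: without it, the final reshape would group 8 neighbouring grid positions of one channel into a capsule instead of 8 channels at one position.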

DigitCapsule Layer and Dynamic Routing

Each of the 1,152 primary capsules sends a prediction to all 10 digit capsules, and dynamic routing decides how much weight each prediction receives. This is where the architecture gets interesting.

Weight Matrix

# W: (1, 1152, 10, 16, 8)
self.W = nn.Parameter(
    torch.randn(1, 1152, 10, 16, 8) * 0.01
)

A 5D tensor holding $1{,}152 \times 10$ transformation matrices, each $16 \times 8$. For primary capsule $i$ and digit capsule $j$, $\mathbf{W}_{ij}$ transforms the 8D input to a 16D prediction:

$$ \hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \, \mathbf{u}_i $$

Computing Prediction Vectors

# u: (B, 1152, 8) -> (B, 1152, 1, 8, 1)
u = u.unsqueeze(2).unsqueeze(4)

# W: (1, 1152, 10, 16, 8) @ u: (B, 1152, 1, 8, 1)
# -> u_hat: (B, 1152, 10, 16, 1) -> squeeze -> (B, 1152, 10, 16)
u_hat = torch.matmul(self.W, u).squeeze(-1)

Broadcasting handles the batch and 10-class dimensions in one shot. Result: 1,152 prediction vectors of dimension 16 for each of 10 classes.
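To confirm that the broadcasted matmul computes $\mathbf{W}_{ij}\,\mathbf{u}_i$ for every pair, it can be compared against an explicit einsum and a single hand-picked pair. This is a sketch with random tensors standing in for the learned weights and the primary capsule outputs (both assumptions).

```python
import torch

torch.manual_seed(0)
B = 2
W = torch.randn(1, 1152, 10, 16, 8) * 0.01   # stand-in for self.W
u = torch.randn(B, 1152, 8)                   # stand-in for primary capsule output

# Broadcasted matmul, as in the layer
u_hat = torch.matmul(W, u.unsqueeze(2).unsqueeze(4)).squeeze(-1)

# Same contraction written out: i = input capsule, j = digit capsule,
# d = output dimension (16), k = input dimension (8)
u_hat_einsum = torch.einsum("ijdk,bik->bijd", W.squeeze(0), u)

# One pair computed directly: W_ij @ u_i for batch item 0
i, j = 5, 3
manual = W[0, i, j] @ u[0, i]

print(u_hat.shape)
print(torch.allclose(u_hat, u_hat_einsum, atol=1e-6))
print(torch.allclose(u_hat[0, i, j], manual, atol=1e-6))
```

The einsum form avoids the two unsqueeze calls and can be easier to read, at the cost of hiding the per-pair matrix structure.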

Dynamic Routing Implementation

def dynamic_routing(self, u_hat, u_hat_detached, B):
    # Routing logits b_ij start at zero, so the first pass couples uniformly
    b = torch.zeros(B, u_hat.size(1), u_hat.size(2), device=u_hat.device)

    for r in range(self.num_routing):
        c = F.softmax(b, dim=2)      # coupling coefficients

        if r < self.num_routing - 1:
            s = (c.unsqueeze(-1) * u_hat_detached).sum(dim=1)
            v = squash(s)
            agreement = (u_hat_detached * v.unsqueeze(1)).sum(-1)
            b = b + agreement
        else:
            # Last iteration: allow gradients through
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)
            v = squash(s)

    return v  # (B, 10, 16)
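A toy case shows what the agreement update does. In this sketch (hand-picked numbers, three input capsules and two output capsules, all assumptions for illustration), every input agrees on its prediction for output 0 and disagrees on output 1, so routing should shift coupling toward output 0.

```python
import torch
import torch.nn.functional as F

def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    safe_norm = torch.sqrt(squared_norm + 1e-8)
    return squared_norm / (1.0 + squared_norm) * s / safe_norm

# u_hat: (B=1, 3 input capsules, 2 output capsules, dim 2)
u_hat = torch.tensor([[
    [[1.0, 0.0], [1.0, 0.0]],    # input 0: predictions for output 0 and output 1
    [[1.0, 0.0], [-1.0, 0.0]],   # input 1
    [[1.0, 0.0], [0.0, 1.0]],    # input 2
]])

b = torch.zeros(1, 3, 2)
for _ in range(2):
    c = F.softmax(b, dim=2)                      # coupling per input capsule
    s = (c.unsqueeze(-1) * u_hat).sum(dim=1)     # (1, 2, 2)
    v = squash(s)
    b = b + (u_hat * v.unsqueeze(1)).sum(-1)     # dot-product agreement

c = F.softmax(b, dim=2)
v_len = v.norm(dim=-1)
print(c[0])     # every input capsule now favours output 0
print(v_len)    # output 0 is much longer than output 1
```

The predictions for output 1 partly cancel in the weighted sum, so its capsule stays short and earns little agreement.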

Gradient Detachment

With three routing iterations, the first two use u_hat_detached -- the prediction vectors with gradients cut off (for example, u_hat.detach()). Only the final iteration uses u_hat with live gradients.

Backpropagating through every routing iteration builds a deeper computation graph, which raises memory use and can lead to vanishing or exploding gradients and unstable training, without a clear convergence benefit. The final iteration alone supplies enough gradient signal to learn the weights.
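One way the weight tensor, the prediction step, and the detached routing loop might fit into a single module is sketched below. The constructor arguments and the use of `u_hat.detach()` are assumptions made for this sketch, and the routing loop is folded into forward instead of living in a separate method.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    safe_norm = torch.sqrt(squared_norm + 1e-8)
    return squared_norm / (1.0 + squared_norm) * s / safe_norm

class DigitCapsule(nn.Module):
    def __init__(self, num_in=1152, num_out=10, in_dim=8, out_dim=16,
                 num_routing=3):
        super().__init__()
        self.num_routing = num_routing
        self.W = nn.Parameter(
            torch.randn(1, num_in, num_out, out_dim, in_dim) * 0.01)

    def forward(self, u):
        # u: (B, num_in, in_dim) -> u_hat: (B, num_in, num_out, out_dim)
        u_hat = torch.matmul(self.W, u.unsqueeze(2).unsqueeze(4)).squeeze(-1)
        u_hat_detached = u_hat.detach()   # same values, no gradient path to W

        b = torch.zeros(u_hat.shape[:3], device=u_hat.device,
                        dtype=u_hat.dtype)
        for r in range(self.num_routing):
            last = r == self.num_routing - 1
            c = F.softmax(b, dim=2)
            # Only the final pass uses the tensor that carries gradients
            preds = u_hat if last else u_hat_detached
            v = squash((c.unsqueeze(-1) * preds).sum(dim=1))
            if not last:
                b = b + (u_hat_detached * v.unsqueeze(1)).sum(-1)
        return v   # (B, num_out, out_dim)

layer = DigitCapsule()
u = squash(torch.randn(2, 1152, 8))   # random stand-in for primary capsules
v = layer(u)
v.sum().backward()
print(v.shape, layer.W.grad.abs().sum().item() > 0)
```

The coupling coefficients in the last pass are constants as far as autograd is concerned, yet W still receives gradients through the final weighted sum.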

Margin Loss

Per-class margin loss for classification:

$$ L_c = T_c \max(0,\; m^+ - \|\mathbf{v}_c\|)^2 + \lambda_{\text{neg}} (1 - T_c) \max(0,\; \|\mathbf{v}_c\| - m^-)^2 $$

where $T_c = 1$ if class $c$ is present and $0$ otherwise, $m^+ = 0.9$, $m^- = 0.1$, and $\lambda_{\text{neg}} = 0.5$.

v_length = torch.sqrt((v**2).sum(dim=-1) + 1e-8)
T = torch.zeros(B, 10, device=v.device)
T.scatter_(1, labels.unsqueeze(1), 1.0)

present = T * F.relu(0.9 - v_length)**2
absent  = 0.5 * (1-T) * F.relu(v_length - 0.1)**2
margin_loss = (present + absent).sum(dim=1).mean()
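A worked example with hand-built capsule outputs shows the two extremes of the loss. The function name `margin_loss_per_sample` and its keyword arguments are illustrative, and `F.one_hot` is swapped in for the `scatter_` call above; the arithmetic is otherwise the same.

```python
import torch
import torch.nn.functional as F

def margin_loss_per_sample(v, labels, m_pos=0.9, m_neg=0.1, lam=0.5):
    v_length = torch.sqrt((v ** 2).sum(dim=-1) + 1e-8)           # (B, 10)
    T = F.one_hot(labels, num_classes=v.size(1)).to(v.dtype)     # (B, 10)
    present = T * F.relu(m_pos - v_length) ** 2
    absent = lam * (1 - T) * F.relu(v_length - m_neg) ** 2
    return (present + absent).sum(dim=1)                         # (B,)

v = torch.zeros(2, 10, 16)
# Sample 0: correct capsule has length 0.9, all others length 0.1 -> loss near 0
v[0, :, 0] = 0.1
v[0, 3, 0] = 0.9
# Sample 1: every capsule is zero -> only the "present" term fires, about 0.9^2
labels = torch.tensor([3, 7])

per_sample = margin_loss_per_sample(v, labels)
loss = per_sample.mean()
print(per_sample, loss)
```

Lengths between the two margins are penalized only on the side that matters: the correct capsule for being too short, the others for being too long.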

Reconstruction Decoder

The decoder regularizes training by forcing the 16D capsule vectors to retain enough information to reconstruct the input.

self.decoder = nn.Sequential(
    nn.Linear(16, 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 1024),
    nn.ReLU(inplace=True),
    nn.Linear(1024, 784),
    nn.Sigmoid()
)

During training, only the correct class capsule feeds into the decoder (others masked to zero). At inference, the longest capsule is used.

Reconstruction loss is MSE, weighted by $\lambda_{\text{recon}} = 0.0005$:

$$ L = L_{\text{margin}} + 0.0005 \times L_{\text{recon}} $$

The small weight keeps reconstruction as a regularizer without drowning out the margin loss.
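How much the reconstruction term contributes depends on how the squared error is reduced. As far as the original paper describes it, the 0.0005 weight is applied to a squared error summed over pixels, while `F.mse_loss` averages over all elements by default, which makes the term 784 times smaller on MNIST. The sketch below exposes that choice as a `reduction` argument; the function name `total_loss` and this argument are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def total_loss(margin_loss, recon, images, recon_weight=0.0005,
               reduction="mean"):
    target = images.view(images.size(0), -1)       # (B, 784)
    if reduction == "mean":
        # Average over batch and pixels (default F.mse_loss behaviour)
        recon_loss = F.mse_loss(recon, target)
    else:
        # Sum over pixels, averaged over the batch
        recon_loss = F.mse_loss(recon, target, reduction="sum") / images.size(0)
    return margin_loss + recon_weight * recon_loss

images = torch.zeros(4, 1, 28, 28)
recon = torch.full((4, 784), 0.5)    # every pixel off by 0.5
margin = torch.tensor(1.0)

loss_mean = total_loss(margin, recon, images, reduction="mean")
loss_sum = total_loss(margin, recon, images, reduction="sum")
print(loss_mean.item(), loss_sum.item())
```

Whichever reduction is used, it is worth checking that the weighted reconstruction term stays well below the margin loss early in training.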

Putting It All Together

def forward(self, x, y=None):
    out = F.relu(self.conv1(x))        # (B, 256, 20, 20)
    primary = self.primary_caps(out)    # (B, 1152, 8)
    v = self.digit_caps(primary)        # (B, 10, 16)

    probs = torch.sqrt((v**2).sum(dim=-1) + 1e-8)

    # Mask: keep only the correct capsule (training) or the longest one (inference)
    idx = y if y is not None else probs.argmax(dim=1)
    mask = torch.zeros_like(probs)
    mask.scatter_(1, idx.unsqueeze(1), 1.0)

    masked_v = (v * mask.unsqueeze(-1)).sum(dim=1)
    recon = self.decoder(masked_v)
    return v, recon
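The mask-and-sum step selects exactly one 16D capsule per sample, so it should match direct indexing. The check below uses random capsule outputs and hand-picked labels (assumptions for illustration) and also shows the inference-time choice of the longest capsule.

```python
import torch

torch.manual_seed(0)
B = 3
v = torch.randn(B, 10, 16)           # stand-in for digit capsule outputs
y = torch.tensor([7, 0, 3])          # labels used for masking during training

# Mask-and-sum, as in forward
mask = torch.zeros(B, 10)
mask.scatter_(1, y.unsqueeze(1), 1.0)
masked_v = (v * mask.unsqueeze(-1)).sum(dim=1)    # (B, 16)

# Equivalent direct indexing
picked = v[torch.arange(B), y]                    # (B, 16)
same = torch.allclose(masked_v, picked)

# Inference: no labels, so pick the longest capsule per sample
lengths = v.norm(dim=-1)                          # (B, 10)
pred = lengths.argmax(dim=1)                      # (B,)
print(same, pred)
```

Direct indexing is shorter, while the mask form extends naturally to feeding the full $10 \times 16$ masked tensor into a wider decoder.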

Training Configuration

Next

Part 3 tests the trained CapsNet: standard MNIST accuracy, reconstruction quality from 16D vectors, and rotation robustness compared to a simple CNN baseline.