Deconstructing CapsNets from Scratch

Part 2: PyTorch Implementation

Introduction

This second installment translates the mathematics of Capsule Networks into working PyTorch code. We implement every component from scratch: the squashing non-linearity, PrimaryCapsule layer, DigitCapsule layer with dynamic routing, margin loss, and reconstruction decoder.

Along the way, we address the implementation details that the original paper leaves implicit -- tensor shapes in the routing algorithm, gradient detachment for stability, and the careful balance between classification and reconstruction losses.

Architecture Overview

The full CapsNet architecture for MNIST consists of four stages:

  1. Conv1: Conv2d(1, 256, kernel_size=9, stride=1) + ReLU. Maps input $(B, 1, 28, 28)$ to output $(B, 256, 20, 20)$
  2. PrimaryCapsule: Conv2d(256, 256, kernel_size=9, stride=2) + Reshape + Squash. Output: $(B, 1152, 8)$ -- 1,152 capsules of dimension 8
  3. DigitCapsule: Dynamic routing with learned weight matrices. Output: $(B, 10, 16)$ -- 10 class capsules of dimension 16
  4. Decoder: Linear(16 -> 512 -> 1024 -> 784) with masking. Reconstructs the input image from the winning capsule

Total parameters: 8,141,840.
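As a sanity check, this total can be reproduced by hand from the layer sizes above (bias terms included; the routing weight tensor has no bias, and the decoder takes a single masked 16-D vector as input):

```python
# Parameter count per stage, from the architecture listed above.
conv1   = 9 * 9 * 1 * 256 + 256          # Conv1 weights + biases
primary = 9 * 9 * 256 * 256 + 256        # PrimaryCapsule conv weights + biases
w_digit = 1152 * 10 * 16 * 8             # routing weight tensor W (no bias)
decoder = (16 * 512 + 512) + (512 * 1024 + 1024) + (1024 * 784 + 784)

total = conv1 + primary + w_digit + decoder
print(total)  # 8141840
```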

The Squashing Function

def squash(s, dim=-1):
    # Squared norm of each capsule, kept as a trailing singleton dim
    # so it broadcasts against s below.
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    # Epsilon inside the sqrt avoids NaN gradients at zero activation.
    safe_norm = torch.sqrt(squared_norm + 1e-8)
    scale = squared_norm / (1.0 + squared_norm)   # in [0, 1)
    return scale * s / safe_norm                  # scaled unit vector

Key implementation detail: the 1e-8 inside the square root prevents division by zero when a capsule has near-zero activation. Without this, gradients become NaN during backpropagation.

The function computes:

$$ \mathbf{v} = \underbrace{\frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2}}_{\text{scale factor}} \cdot \underbrace{\frac{\mathbf{s}}{\|\mathbf{s}\|}}_{\text{unit vector}} $$

The keepdim=True argument is essential -- it ensures the scaling factor broadcasts correctly against the input tensor.
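A quick numerical check of the function as defined above: short vectors are shrunk almost to zero, long vectors approach (but never reach) unit length, and the zero vector produces no NaN thanks to the epsilon.

```python
import torch

def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    safe_norm = torch.sqrt(squared_norm + 1e-8)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / safe_norm

short = torch.tensor([0.1, 0.0, 0.0])   # |s| = 0.1
long  = torch.tensor([10.0, 0.0, 0.0])  # |s| = 10

print(squash(short).norm())   # ~0.0099 = 0.01 / 1.01
print(squash(long).norm())    # ~0.9901 = 100 / 101
print(squash(torch.zeros(3))) # all zeros, no NaN
```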

PrimaryCapsule Layer

The PrimaryCapsule layer converts standard convolutional features into capsule vectors.

class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels=256, num_types=32,
                 caps_dim=8, kernel_size=9, stride=2):
        super().__init__()
        self.caps_dim = caps_dim
        self.conv = nn.Conv2d(in_channels,
                              num_types * caps_dim,
                              kernel_size, stride)

    def forward(self, x):
        out = self.conv(x)       # (B, 256, 6, 6)
        B = out.size(0)
        out = out.view(B, 32, 8, -1)  # (B, 32, 8, 36)
        out = out.permute(0,1,3,2)     # (B, 32, 36, 8)
        out = out.reshape(B, -1, 8)    # (B, 1152, 8)
        return squash(out)

Shape Walkthrough
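The shapes in the forward pass above can be traced with plain arithmetic (assuming 28x28 MNIST input, so Conv1 outputs 20x20 feature maps):

```python
# PrimaryCapsule shape walkthrough for MNIST.
h_in = 20                            # spatial size after Conv1
k, s = 9, 2                          # PrimaryCapsule kernel and stride
h_out = (h_in - k) // s + 1          # standard conv output-size formula
print(h_out)                         # 6 -> conv output (B, 256, 6, 6)

num_types, caps_dim = 32, 8
spatial = h_out * h_out              # 36 spatial locations per capsule type
num_capsules = num_types * spatial   # 32 types x 36 locations
print(num_capsules)                  # 1152 -> final shape (B, 1152, 8)
```

The view/permute/reshape sequence in the code simply regroups the 256 channels into 32 blocks of 8, then flattens capsule type and spatial location into a single axis of 1,152 capsules.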

DigitCapsule Layer and Dynamic Routing

This is the core innovation. Each of the 1,152 primary capsules must route to one of 10 digit capsules.

Weight Matrix

# W: (1, 1152, 10, 16, 8)
self.W = nn.Parameter(
    torch.randn(1, 1152, 10, 16, 8) * 0.01
)

This 5D tensor contains $1{,}152 \times 10$ transformation matrices, each of size $16 \times 8$. For each primary capsule $i$ and digit capsule $j$, $\mathbf{W}_{ij}$ transforms the 8D primary vector to a 16D prediction vector:

$$ \hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \, \mathbf{u}_i $$

Computing Prediction Vectors

# u: (B, 1152, 8) -> (B, 1152, 1, 8, 1)
u = u.unsqueeze(2).unsqueeze(4)

# W: (1, 1152, 10, 16, 8) @ u: (B, 1152, 1, 8, 1)
# -> u_hat: (B, 1152, 10, 16, 1) -> squeeze -> (B, 1152, 10, 16)
u_hat = torch.matmul(self.W, u).squeeze(-1)

Broadcasting handles the batch dimension and the 10-class dimension simultaneously. The result is 1,152 prediction vectors of dimension 16 for each of 10 classes.
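The broadcasting can be verified with dummy tensors (shapes as above; batch size 2 chosen for illustration):

```python
import torch

B = 2
W = torch.randn(1, 1152, 10, 16, 8) * 0.01
u = torch.randn(B, 1152, 8)

u = u.unsqueeze(2).unsqueeze(4)          # (B, 1152, 1, 8, 1)
# matmul broadcasts the leading dims (1 vs B, 1 vs 10) and multiplies
# the trailing (16, 8) @ (8, 1) matrices.
u_hat = torch.matmul(W, u).squeeze(-1)
print(u_hat.shape)                       # torch.Size([2, 1152, 10, 16])
```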

Dynamic Routing Implementation

def dynamic_routing(self, u_hat, u_hat_detached, B):
    # B is the batch size; b holds the routing logits b_ij.
    b = torch.zeros(B, 1152, 10, device=u_hat.device)

    for r in range(self.num_routing):
        c = F.softmax(b, dim=2)      # coupling coefficients

        if r < self.num_routing - 1:
            s = (c.unsqueeze(-1) * u_hat_detached).sum(dim=1)
            v = squash(s)
            agreement = (u_hat_detached * v.unsqueeze(1)).sum(-1)
            b = b + agreement
        else:
            # Last iteration: allow gradients through
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)
            v = squash(s)

    return v  # (B, 10, 16)

Critical Detail: Gradient Detachment

For routing iterations 1 and 2, we use u_hat_detached -- the prediction vectors with gradients detached. Only on the final iteration do we use u_hat with gradients intact.

Why? Backpropagating through all routing iterations creates deep computation graphs that consume extra memory, slow down the backward pass, and can destabilize gradients.

The final iteration alone provides sufficient gradient signal to learn the weight matrices.
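The mechanism is ordinary .detach() behavior, shown here in miniature: gradients flow only through the non-detached path.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)

# The detached path contributes nothing to the gradient, mirroring how
# the intermediate routing iterations use u_hat_detached.
y = (x.detach() * 3) + x   # only the "+ x" term carries gradient
y.backward()
print(x.grad)              # tensor([1.]) -- the factor of 3 is invisible
```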

Margin Loss

Classification uses a per-class margin loss:

$$ L_c = T_c \max(0,\; m^+ - \|\mathbf{v}_c\|)^2 + \lambda_{\text{neg}} (1 - T_c) \max(0,\; \|\mathbf{v}_c\| - m^-)^2 $$

where $T_c = 1$ if class $c$ is present, $m^+ = 0.9$, $m^- = 0.1$, and $\lambda_{\text{neg}} = 0.5$.

v_length = torch.sqrt((v**2).sum(dim=-1) + 1e-8)
T = torch.zeros(B, 10, device=v.device)
T.scatter_(1, labels.unsqueeze(1), 1.0)

present = T * F.relu(0.9 - v_length)**2
absent  = 0.5 * (1-T) * F.relu(v_length - 0.1)**2
margin_loss = (present + absent).sum(dim=1).mean()
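A minimal sanity check of the snippet above (batch of one, label 3): with an ideal output -- correct capsule at length 1, all others at length 0 -- both hinge terms clamp to zero and the loss vanishes.

```python
import torch
import torch.nn.functional as F

B = 1
v = torch.zeros(B, 10, 16)
v[0, 3, 0] = 1.0               # correct class capsule has length 1
labels = torch.tensor([3])

v_length = torch.sqrt((v**2).sum(dim=-1) + 1e-8)
T = torch.zeros(B, 10)
T.scatter_(1, labels.unsqueeze(1), 1.0)

present = T * F.relu(0.9 - v_length)**2
absent  = 0.5 * (1 - T) * F.relu(v_length - 0.1)**2
margin_loss = (present + absent).sum(dim=1).mean()
print(margin_loss)             # tensor(0.) -- ideal prediction, zero loss
```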

Interpretation: the present term penalizes the correct class capsule whenever its length falls below $m^+ = 0.9$; the absent term penalizes every incorrect class capsule whose length exceeds $m^- = 0.1$. The down-weighting factor $\lambda_{\text{neg}} = 0.5$ keeps the absent loss from shrinking all capsule lengths early in training, when every class is "absent" for most examples.

Reconstruction Decoder

The decoder serves as a regularizer, forcing the 16D capsule vectors to encode enough information to reconstruct the input image.

self.decoder = nn.Sequential(
    nn.Linear(16, 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 1024),
    nn.ReLU(inplace=True),
    nn.Linear(1024, 784),
    nn.Sigmoid()
)

During training, only the correct class capsule is fed to the decoder (all others are masked to zero). During inference, the capsule with the largest length is used.

The reconstruction loss is MSE, weighted by $\lambda_{\text{recon}} = 0.0005$:

$$ L = L_{\text{margin}} + 0.0005 \times L_{\text{recon}} $$

The small weight ensures reconstruction serves as regularization without overwhelming the margin loss.
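Combining the two losses might look like the following sketch (margin_loss computed as shown earlier; the `reduction='sum'` choice for the MSE term is an assumption, not fixed by the article):

```python
import torch
import torch.nn.functional as F

def total_loss(margin_loss, recon, x, lam_recon=0.0005):
    # Flatten the input image to match the decoder's 784-dim output.
    target = x.view(x.size(0), -1)
    recon_loss = F.mse_loss(recon, target, reduction='sum')
    return margin_loss + lam_recon * recon_loss
```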

Putting It All Together

def forward(self, x, y=None):
    out = F.relu(self.conv1(x))        # (B, 256, 20, 20)
    primary = self.primary_caps(out)    # (B, 1152, 8)
    v = self.digit_caps(primary)        # (B, 10, 16)

    probs = torch.sqrt((v**2).sum(dim=-1) + 1e-8)

    # Mask: keep only correct/predicted capsule
    if y is not None:
        mask = torch.zeros_like(probs)
        mask.scatter_(1, y.unsqueeze(1), 1.0)
    else:
        _, idx = probs.max(dim=1)
        mask = torch.zeros_like(probs)
        mask.scatter_(1, idx.unsqueeze(1), 1.0)

    masked_v = (v * mask.unsqueeze(-1)).sum(dim=1)
    recon = self.decoder(masked_v)
    return v, recon

Training Configuration
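The article does not spell out the configuration here; a typical setup for this model might look like the sketch below. All hyperparameter values are assumptions (common CapsNet-on-MNIST choices), not taken from the article.

```python
import torch

# Assumed hyperparameters -- typical choices, not specified in the article.
config = {
    "batch_size": 128,
    "epochs": 50,
    "lr": 1e-3,
    "num_routing": 3,       # routing iterations, as in the routing loop above
    "lam_recon": 0.0005,    # reconstruction weight from the loss section
}

# model = CapsNet()  # the class assembled from the components above
# optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)
```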

Looking Ahead

With the implementation complete, Part 3 presents the experimental results: MNIST classification accuracy on standard and rotated digits, reconstruction quality from 16D capsule representations, and a direct comparison with a simple CNN baseline quantifying the rotation robustness advantage.