This second installment translates the mathematics of Capsule Networks into working PyTorch code. We implement every component from scratch: the squashing non-linearity, PrimaryCapsule layer, DigitCapsule layer with dynamic routing, margin loss, and reconstruction decoder.
Along the way, we address the implementation details that the original paper leaves implicit -- tensor shapes in the routing algorithm, gradient detachment for stability, and the careful balance between classification and reconstruction losses.
Architecture Overview
The full CapsNet architecture for MNIST consists of four stages:
- Conv1: `Conv2d(1, 256, kernel_size=9, stride=1)` + ReLU. Input $(B, 1, 28, 28)$, output $(B, 256, 20, 20)$.
- PrimaryCapsule: `Conv2d(256, 256, kernel_size=9, stride=2)` + reshape + squash. Output $(B, 1152, 8)$ -- 1,152 capsules of dimension 8.
- DigitCapsule: dynamic routing with learned weight matrices. Output $(B, 10, 16)$ -- 10 class capsules of dimension 16.
- Decoder: `Linear(16 -> 512 -> 1024 -> 784)` with masking. Reconstructs the input image from the winning capsule.
Total parameters: 8,141,840.
The Squashing Function
def squash(s, dim=-1):
squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
safe_norm = torch.sqrt(squared_norm + 1e-8)
scale = squared_norm / (1.0 + squared_norm)
return scale * s / safe_norm
Key implementation detail: the 1e-8 inside the square root prevents division by zero when a capsule has near-zero activation. Without this, gradients become NaN during backpropagation.
The function computes:

$$\mathbf{v} = \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2} \, \frac{\mathbf{s}}{\|\mathbf{s}\|}$$
The keepdim=True argument is essential -- it ensures the scaling factor broadcasts correctly against the input tensor.
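As a quick sanity check, squash should preserve each vector's direction while mapping its length into $[0, 1)$. A minimal sketch using the function above:

```python
import torch

def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    safe_norm = torch.sqrt(squared_norm + 1e-8)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / safe_norm

caps = torch.randn(4, 1152, 8)   # a batch of toy capsule vectors
v = squash(caps)
lengths = v.norm(dim=-1)
# every capsule length is strictly below 1; shape is unchanged
print(lengths.max() < 1.0, v.shape)
```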
PrimaryCapsule Layer
The PrimaryCapsule layer converts standard convolutional features into capsule vectors.
class PrimaryCapsule(nn.Module):
def __init__(self, in_channels=256, num_types=32,
caps_dim=8, kernel_size=9, stride=2):
super().__init__()
self.caps_dim = caps_dim
self.conv = nn.Conv2d(in_channels,
num_types * caps_dim,
kernel_size, stride)
def forward(self, x):
out = self.conv(x) # (B, 256, 6, 6)
B = out.size(0)
out = out.view(B, 32, 8, -1) # (B, 32, 8, 36)
out = out.permute(0,1,3,2) # (B, 32, 36, 8)
out = out.reshape(B, -1, 8) # (B, 1152, 8)
return squash(out)
Shape Walkthrough
- Input: $(B, 256, 20, 20)$ from Conv1.
- After `Conv2d(256, 256, 9, stride=2)`: spatial size $\lfloor (20-9)/2 \rfloor + 1 = 6$, giving $(B, 256, 6, 6)$.
- The 256 channels encode 32 capsule types $\times$ 8 dimensions each.
- Reshape and permute to get $(B, 1152, 8)$: 1,152 capsules ($32 \times 6 \times 6$) of dimension 8.
- Squash: ensure each capsule has length $\in [0, 1)$.
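The walkthrough can be verified by running the layer on a dummy Conv1 output; this sketch repeats the class and squash definitions from above so it runs standalone:

```python
import torch
import torch.nn as nn

def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return squared_norm / (1.0 + squared_norm) * s / torch.sqrt(squared_norm + 1e-8)

class PrimaryCapsule(nn.Module):
    def __init__(self, in_channels=256, num_types=32, caps_dim=8,
                 kernel_size=9, stride=2):
        super().__init__()
        self.caps_dim = caps_dim
        self.conv = nn.Conv2d(in_channels, num_types * caps_dim,
                              kernel_size, stride)

    def forward(self, x):
        out = self.conv(x)             # (B, 256, 6, 6)
        B = out.size(0)
        out = out.view(B, 32, 8, -1)   # (B, 32, 8, 36)
        out = out.permute(0, 1, 3, 2)  # (B, 32, 36, 8)
        out = out.reshape(B, -1, 8)    # (B, 1152, 8)
        return squash(out)

x = torch.randn(2, 256, 20, 20)  # dummy Conv1 output
caps = PrimaryCapsule()(x)
print(caps.shape)                # torch.Size([2, 1152, 8])
```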
DigitCapsule Layer and Dynamic Routing
This is the core innovation. Each of the 1,152 primary capsules must route to one of 10 digit capsules.
Weight Matrix
# W: (1, 1152, 10, 16, 8)
self.W = nn.Parameter(
torch.randn(1, 1152, 10, 16, 8) * 0.01
)
This 5D tensor contains $1{,}152 \times 10$ transformation matrices, each of size $16 \times 8$. For each primary capsule $i$ and digit capsule $j$, $\mathbf{W}_{ij}$ transforms the 8D primary vector into a 16D prediction vector:

$$\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \mathbf{u}_i$$
Computing Prediction Vectors
# u: (B, 1152, 8) -> (B, 1152, 1, 8, 1)
u = u.unsqueeze(2).unsqueeze(4)
# W: (1, 1152, 10, 16, 8) @ u: (B, 1152, 1, 8, 1)
# -> u_hat: (B, 1152, 10, 16, 1) -> squeeze -> (B, 1152, 10, 16)
u_hat = torch.matmul(self.W, u).squeeze(-1)
Broadcasting handles the batch dimension and the 10-class dimension simultaneously. The result is 1,152 prediction vectors of dimension 16 for each of 10 classes.
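The broadcasting can be checked in isolation with random tensors of the stated shapes:

```python
import torch

W = torch.randn(1, 1152, 10, 16, 8) * 0.01  # weights shared across the batch
u = torch.randn(2, 1152, 8)                 # batch of primary capsule vectors

u = u.unsqueeze(2).unsqueeze(4)             # (2, 1152, 1, 8, 1)
# batch dims broadcast: (1, 1152, 10) x (2, 1152, 1) -> (2, 1152, 10)
# matrix product: (16, 8) @ (8, 1) -> (16, 1)
u_hat = torch.matmul(W, u).squeeze(-1)
print(u_hat.shape)                          # torch.Size([2, 1152, 10, 16])
```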
Dynamic Routing Implementation
def dynamic_routing(self, u_hat, u_hat_detached, B):
b = torch.zeros(B, 1152, 10, device=u_hat.device)
for r in range(self.num_routing):
c = F.softmax(b, dim=2) # coupling coefficients
if r < self.num_routing - 1:
s = (c.unsqueeze(-1) * u_hat_detached).sum(dim=1)
v = squash(s)
agreement = (u_hat_detached * v.unsqueeze(1)).sum(-1)
b = b + agreement
else:
# Last iteration: allow gradients through
s = (c.unsqueeze(-1) * u_hat).sum(dim=1)
v = squash(s)
return v # (B, 10, 16)
Critical Detail: Gradient Detachment
For routing iterations 1 and 2, we use u_hat_detached -- the prediction vectors with gradients detached. Only on the final iteration do we use u_hat with gradients intact.
Why? Backpropagating through all routing iterations creates deep computation graphs that:
- Cause vanishing/exploding gradients through the iterative process.
- Dramatically increase memory usage.
- Yield unstable training without improving convergence.
The final iteration alone provides sufficient gradient signal to learn the weight matrices.
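A minimal illustration of the mechanism, using scalar stand-ins rather than real capsules: quantities computed from a detached tensor carry no computation graph, so the backward pass traverses only the final, non-detached use.

```python
import torch

w = torch.randn(5, requires_grad=True)
u_hat = w * 2.0                  # stand-in for the prediction vectors
u_hat_detached = u_hat.detach()

# "routing" update built from the detached copy: no graph is recorded
b = (u_hat_detached * u_hat_detached).sum()
assert not b.requires_grad

# the final step uses the live tensor, so gradients flow through it alone
v = (b * u_hat).sum()
v.backward()
assert w.grad is not None        # w still receives a gradient
```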
Margin Loss
Classification uses a per-class margin loss:

$$L_c = T_c \max(0,\, m^+ - \|\mathbf{v}_c\|)^2 + \lambda_{\text{neg}} (1 - T_c) \max(0,\, \|\mathbf{v}_c\| - m^-)^2$$

where $T_c = 1$ if class $c$ is present, $m^+ = 0.9$, $m^- = 0.1$, and $\lambda_{\text{neg}} = 0.5$.
v_length = torch.sqrt((v**2).sum(dim=-1) + 1e-8)
T = torch.zeros(B, 10, device=v.device)
T.scatter_(1, labels.unsqueeze(1), 1.0)
present = T * F.relu(0.9 - v_length)**2
absent = 0.5 * (1-T) * F.relu(v_length - 0.1)**2
margin_loss = (present + absent).sum(dim=1).mean()
Interpretation:
- For the correct class: penalize if $\|\mathbf{v}_c\| < 0.9$ (push length up).
- For incorrect classes: penalize if $\|\mathbf{v}_c\| > 0.1$ (push length down).
- The $0.5$ factor down-weights the absent-class loss to prevent the network from killing all capsules early in training.
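A worked example with three toy classes (class 0 correct) makes the two branches concrete:

```python
import torch
import torch.nn.functional as F

v_length = torch.tensor([[0.95, 0.30, 0.05]])  # toy capsule lengths
T = torch.tensor([[1.0, 0.0, 0.0]])            # class 0 is present

present = T * F.relu(0.9 - v_length) ** 2      # 0.95 > 0.9: no penalty
absent = 0.5 * (1 - T) * F.relu(v_length - 0.1) ** 2
# class 1: 0.5 * (0.30 - 0.1)^2 = 0.02; class 2: 0.05 < 0.1, no penalty
loss = (present + absent).sum(dim=1)
print(loss)                                    # tensor([0.0200])
```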
Reconstruction Decoder
The decoder serves as a regularizer, forcing the 16D capsule vectors to encode enough information to reconstruct the input image.
self.decoder = nn.Sequential(
nn.Linear(16, 512),
nn.ReLU(inplace=True),
nn.Linear(512, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 784),
nn.Sigmoid()
)
During training, only the correct class capsule is fed to the decoder (all others are masked to zero). During inference, the capsule with the largest length is used.
The reconstruction loss is MSE between the input pixels and the decoder output, weighted by $\lambda_{\text{recon}} = 0.0005$:

$$L_{\text{total}} = L_{\text{margin}} + \lambda_{\text{recon}} \, \|\mathbf{x} - \hat{\mathbf{x}}\|^2$$
The small weight ensures reconstruction serves as regularization without overwhelming the margin loss.
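Combining the two terms looks like the following sketch; the margin loss and decoder output are stand-in tensors here, and the exact reduction (per-pixel mean vs. sum over pixels) is a design choice:

```python
import torch
import torch.nn.functional as F

B = 4
x = torch.rand(B, 1, 28, 28)      # input batch
recon = torch.rand(B, 784)        # stand-in for the decoder output
margin_loss = torch.tensor(0.1)   # stand-in for the classification loss

# squared error summed over all pixels, then down-weighted
recon_loss = F.mse_loss(recon, x.view(B, -1), reduction='sum')
total_loss = margin_loss + 0.0005 * recon_loss
```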
Putting It All Together
def forward(self, x, y=None):
out = F.relu(self.conv1(x)) # (B, 256, 20, 20)
primary = self.primary_caps(out) # (B, 1152, 8)
v = self.digit_caps(primary) # (B, 10, 16)
probs = torch.sqrt((v**2).sum(dim=-1) + 1e-8)
# Mask: keep only correct/predicted capsule
if y is not None:
mask = torch.zeros_like(probs)
mask.scatter_(1, y.unsqueeze(1), 1.0)
else:
_, idx = probs.max(dim=1)
mask = torch.zeros_like(probs)
mask.scatter_(1, idx.unsqueeze(1), 1.0)
masked_v = (v * mask.unsqueeze(-1)).sum(dim=1)
recon = self.decoder(masked_v)
return v, recon
Training Configuration
- Dataset: MNIST, 5,000-image training subset.
- Optimizer: Adam with learning rate 0.001.
- Batch size: 64.
- Epochs: 20.
- Routing iterations: 3.
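This configuration maps onto a standard PyTorch loop. The sketch below shows only the optimizer wiring: a small stand-in module and a synthetic batch replace the full CapsNet and the MNIST loader, and cross-entropy replaces the margin-plus-reconstruction objective.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))  # stand-in for CapsNet
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

x = torch.rand(64, 1, 28, 28)       # one synthetic batch (batch size 64)
y = torch.randint(0, 10, (64,))
logits = model(x)
loss = F.cross_entropy(logits, y)   # CapsNet uses margin + reconstruction loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```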
Looking Ahead
With the implementation complete, Part 3 presents the experimental results: MNIST classification accuracy on standard and rotated digits, reconstruction quality from 16D capsule representations, and a direct comparison with a simple CNN baseline quantifying the rotation robustness advantage.