This post translates the CapsNet math from Part 1 into PyTorch. We implement every component from scratch: squashing, PrimaryCapsule, DigitCapsule with dynamic routing, margin loss, and the reconstruction decoder.
The paper glosses over several implementation details -- tensor shapes inside routing, gradient detachment, loss weighting -- so we walk through those here.
Architecture Overview
The CapsNet for MNIST has four stages:
- Conv1: Conv2d(1, 256, kernel_size=9, stride=1) + ReLU. $(B, 1, 28, 28) \to (B, 256, 20, 20)$
- PrimaryCapsule: Conv2d(256, 256, kernel_size=9, stride=2) + Reshape + Squash. Output: $(B, 1152, 8)$ -- 1,152 capsules of dimension 8
- DigitCapsule: Dynamic routing with learned weight matrices. Output: $(B, 10, 16)$ -- 10 class capsules of dimension 16
- Decoder: Linear(16 -> 512 -> 1024 -> 784) with masking. Reconstructs the input from the winning capsule
Total parameters: 8,141,840.
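The total can be checked by hand from the layer shapes listed above. The sketch below assumes every convolution and linear layer carries a bias and that the routing weight tensor has none; the helper names `conv_params` and `linear_params` are made up for this illustration.

```python
def conv_params(c_in, c_out, k):
    """Weights plus one bias per output channel for a k x k convolution."""
    return c_in * c_out * k * k + c_out


def linear_params(n_in, n_out):
    """Weights plus one bias per output unit."""
    return n_in * n_out + n_out


conv1 = conv_params(1, 256, 9)        # 20,992
primary = conv_params(256, 256, 9)    # 5,308,672
digit = 1152 * 10 * 16 * 8            # 1,474,560 (routing weights, assumed bias-free)
decoder = (linear_params(16, 512)
           + linear_params(512, 1024)
           + linear_params(1024, 784))  # 1,337,616

total = conv1 + primary + digit + decoder
print(f"{total:,}")  # 8,141,840
```

Roughly two thirds of the parameters sit in the PrimaryCapsule convolution, not in the routing weights.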
The Squashing Function
    def squash(s, dim=-1):
        squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
        safe_norm = torch.sqrt(squared_norm + 1e-8)
        scale = squared_norm / (1.0 + squared_norm)
        return scale * s / safe_norm
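The 1e-8 inside the square root prevents division by zero when a capsule has near-zero activation. Without it, an all-zero capsule yields 0/0 in the forward pass and an infinite sqrt derivative in the backward pass, so activations and gradients become NaN.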
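Mathematically:
$$\mathbf{v} = \frac{\|\mathbf{s}\|^2}{1 + \|\mathbf{s}\|^2} \, \frac{\mathbf{s}}{\|\mathbf{s}\|}$$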
keepdim=True is essential so the scaling factor broadcasts correctly against the input tensor.
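A quick sanity check of the squashing behaviour, as a minimal sketch: it redefines `squash` so the block runs on its own, and the `eps` argument is an addition for illustration (the original hard-codes 1e-8). The inputs are arbitrary test vectors.

```python
import torch


def squash(s, dim=-1, eps=1e-8):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    safe_norm = torch.sqrt(squared_norm + eps)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / safe_norm


s = torch.tensor([[3.0, 4.0],     # norm 5    -> length 25/26, close to 1
                  [0.0, 0.0],     # zero vector -> stays zero, no NaN
                  [0.03, 0.04]],  # norm 0.05 -> length shrinks toward 0
                 requires_grad=True)

v = squash(s)
lengths = v.norm(dim=-1)
v.sum().backward()

print(lengths)                        # every length is below 1
print(torch.isfinite(s.grad).all())   # gradients stay finite, even for the zero row
```

Long vectors keep their direction and end up just under unit length, while short vectors are pushed toward zero, which is what lets the length act as a presence probability.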
PrimaryCapsule Layer
Converts convolutional features into capsule vectors.
    class PrimaryCapsule(nn.Module):
        def __init__(self, in_channels=256, num_types=32,
                     caps_dim=8, kernel_size=9, stride=2):
            super().__init__()
            self.num_types = num_types
            self.caps_dim = caps_dim
            self.conv = nn.Conv2d(in_channels,
                                  num_types * caps_dim,
                                  kernel_size, stride)

        def forward(self, x):
            out = self.conv(x)                                    # (B, 256, 6, 6)
            B = out.size(0)
            # Split channels into (capsule type, capsule dimension)
            out = out.view(B, self.num_types, self.caps_dim, -1)  # (B, 32, 8, 36)
            out = out.permute(0, 1, 3, 2)                         # (B, 32, 36, 8)
            out = out.reshape(B, -1, self.caps_dim)               # (B, 1152, 8)
            return squash(out)
Shape Walkthrough
- Input: $(B, 256, 20, 20)$ from Conv1.
- After Conv2d(256, 256, 9, stride=2): spatial size $\lfloor (20-9)/2 \rfloor + 1 = 6$, so $(B, 256, 6, 6)$.
- 256 channels = 32 capsule types $\times$ 8 dimensions.
- Reshape and permute to $(B, 1152, 8)$: that's $32 \times 6 \times 6 = 1{,}152$ capsules, each 8D.
- Squash to enforce length $\in [0, 1)$.
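The spatial sizes in this walkthrough follow from the standard convolution output formula. A small sketch, with `conv_out` as a hypothetical helper and zero padding assumed (the Conv2d default):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output spatial size of a convolution (floor division, as PyTorch does)."""
    return (size + 2 * padding - kernel) // stride + 1


conv1_hw = conv_out(28, kernel=9, stride=1)          # 28 -> 20
primary_hw = conv_out(conv1_hw, kernel=9, stride=2)  # 20 -> 6
num_caps = 32 * primary_hw * primary_hw              # 32 types x 6 x 6

print(conv1_hw, primary_hw, num_caps)  # 20 6 1152
```

The floor matters for the stride-2 layer: (20 - 9) / 2 is 5.5, which rounds down to 5 before the +1.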
DigitCapsule Layer and Dynamic Routing
Each of the 1,152 primary capsules sends a prediction to all 10 digit capsules, and routing decides how strongly each prediction counts toward each digit. This is where the architecture gets interesting.
Weight Matrix
    # W: (1, 1152, 10, 16, 8)
    self.W = nn.Parameter(
        torch.randn(1, 1152, 10, 16, 8) * 0.01
    )
A 5D tensor holding $1{,}152 \times 10$ transformation matrices, each $16 \times 8$. For primary capsule $i$ and digit capsule $j$, $\mathbf{W}_{ij}$ transforms the 8D input to a 16D prediction:
$$\hat{\mathbf{u}}_{j|i} = \mathbf{W}_{ij} \, \mathbf{u}_i$$
Computing Prediction Vectors
    # u: (B, 1152, 8) -> (B, 1152, 1, 8, 1)
    u = u.unsqueeze(2).unsqueeze(4)
    # W: (1, 1152, 10, 16, 8) @ u: (B, 1152, 1, 8, 1)
    # -> u_hat: (B, 1152, 10, 16, 1) -> squeeze -> (B, 1152, 10, 16)
    u_hat = torch.matmul(self.W, u).squeeze(-1)
Broadcasting handles the batch and 10-class dimensions in one shot. Result: 1,152 prediction vectors of dimension 16 for each of 10 classes.
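To make the broadcasting less opaque, the same contraction can be written with explicit indices via `torch.einsum`. This is an illustrative cross-check, not part of the model: the sizes are shrunk (6 input capsules instead of 1,152) and the names `N_in`, `N_out`, `D_in`, `D_out` are made up for the example.

```python
import torch

torch.manual_seed(0)
B, N_in, N_out, D_in, D_out = 2, 6, 10, 8, 16   # N_in is 1152 in the real model

W = torch.randn(1, N_in, N_out, D_out, D_in) * 0.01
u = torch.randn(B, N_in, D_in)

# Broadcasting matmul, same pattern as the snippet above
u_hat = torch.matmul(W, u.unsqueeze(2).unsqueeze(4)).squeeze(-1)

# Explicit indices: i = input capsule, j = digit capsule,
# k = output dimension, l = input dimension, b = batch
u_hat_einsum = torch.einsum("ijkl,bil->bijk", W[0], u)

print(u_hat.shape)  # torch.Size([2, 6, 10, 16])
print(torch.allclose(u_hat, u_hat_einsum, atol=1e-6))
```

The einsum string spells out that every (input capsule, digit capsule) pair has its own matrix and that only the 8D input axis is summed away.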
Dynamic Routing Implementation
    def dynamic_routing(self, u_hat, u_hat_detached, B):
        # Routing logits start at zero, so the first softmax is uniform
        b = torch.zeros(B, 1152, 10, device=u_hat.device)
        for r in range(self.num_routing):
            c = F.softmax(b, dim=2)  # coupling coefficients
            if r < self.num_routing - 1:
                s = (c.unsqueeze(-1) * u_hat_detached).sum(dim=1)
                v = squash(s)
                agreement = (u_hat_detached * v.unsqueeze(1)).sum(-1)
                b = b + agreement
            else:
                # Last iteration: allow gradients through
                s = (c.unsqueeze(-1) * u_hat).sum(dim=1)
                v = squash(s)
        return v  # (B, 10, 16)
Gradient Detachment
Iterations 1 and 2 use u_hat_detached -- prediction vectors with gradients cut off. Only the final iteration uses u_hat with live gradients.
Backpropagating through every routing iteration unrolls a deeper computation graph, which raises memory use and can cause vanishing or exploding gradients and unstable training, with no convergence benefit. The last iteration alone gives enough gradient signal to learn the weights.
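The pieces of this section can be tied together in one module. The following is a minimal sketch under stated assumptions, not the post's exact class: `DigitCapsuleSketch` is a hypothetical name, the sizes are constructor arguments so the layer can be exercised on small inputs, and the detached copy is assumed to come from `u_hat.detach()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def squash(s, dim=-1):
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return squared_norm / (1.0 + squared_norm) * s / torch.sqrt(squared_norm + 1e-8)


class DigitCapsuleSketch(nn.Module):
    """Hypothetical wrapper around the weight tensor and routing loop."""

    def __init__(self, num_in=1152, num_out=10, in_dim=8, out_dim=16, num_routing=3):
        super().__init__()
        self.num_routing = num_routing
        self.W = nn.Parameter(torch.randn(1, num_in, num_out, out_dim, in_dim) * 0.01)

    def forward(self, u):
        # u: (B, num_in, in_dim) -> u_hat: (B, num_in, num_out, out_dim)
        u_hat = torch.matmul(self.W, u.unsqueeze(2).unsqueeze(4)).squeeze(-1)
        u_hat_detached = u_hat.detach()  # same values, cut from the autograd graph
        b = torch.zeros(u_hat.shape[:3], device=u_hat.device)
        for r in range(self.num_routing):
            c = F.softmax(b, dim=2)
            last = r == self.num_routing - 1
            votes = u_hat if last else u_hat_detached
            v = squash((c.unsqueeze(-1) * votes).sum(dim=1))
            if not last:
                b = b + (u_hat_detached * v.unsqueeze(1)).sum(-1)
        return v  # (B, num_out, out_dim)


layer = DigitCapsuleSketch(num_in=12)   # 12 input capsules keeps the demo small
v = layer(torch.randn(4, 12, 8))
v.sum().backward()

print(v.shape)                    # torch.Size([4, 10, 16])
print(layer.W.grad is not None)   # True
```

Because the routing logits are built only from detached tensors, the coupling coefficients act as constants during the backward pass and the weights receive gradient through the final weighted sum alone.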
Margin Loss
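Per-class margin loss for classification:
$$L_c = T_c \, \max(0,\, m^+ - \|\mathbf{v}_c\|)^2 + \lambda_{\text{neg}} \, (1 - T_c) \, \max(0,\, \|\mathbf{v}_c\| - m^-)^2$$
where $T_c = 1$ if class $c$ is present (0 otherwise), $m^+ = 0.9$, $m^- = 0.1$, and $\lambda_{\text{neg}} = 0.5$.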
    v_length = torch.sqrt((v**2).sum(dim=-1) + 1e-8)
    T = torch.zeros(B, 10, device=v.device)
    T.scatter_(1, labels.unsqueeze(1), 1.0)
    present = T * F.relu(0.9 - v_length)**2
    absent = 0.5 * (1-T) * F.relu(v_length - 0.1)**2
    margin_loss = (present + absent).sum(dim=1).mean()
- Correct class: penalize if $\|\mathbf{v}_c\| < 0.9$ (push length up).
- Wrong classes: penalize if $\|\mathbf{v}_c\| > 0.1$ (push length down).
- The $0.5$ down-weight on absent classes prevents the network from killing all capsules early in training.
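A worked example for a single image makes the three cases above concrete. This is a plain-Python sketch; `margin_loss_single` is a hypothetical helper and the capsule lengths are invented numbers.

```python
def margin_loss_single(lengths, target, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Margin loss for one sample, given the 10 capsule lengths."""
    total = 0.0
    for c, length in enumerate(lengths):
        if c == target:
            total += max(0.0, m_pos - length) ** 2        # present class
        else:
            total += lam * max(0.0, length - m_neg) ** 2  # absent classes
    return total


# Confident and correct: target above 0.9, all others below 0.1
confident = [0.05] * 10
confident[3] = 0.95
print(margin_loss_single(confident, target=3))  # 0.0

# Unsure: target only at 0.5, and class 8 wrongly at 0.3
unsure = [0.05] * 10
unsure[3] = 0.5
unsure[8] = 0.3
print(round(margin_loss_single(unsure, target=3), 4))  # 0.18
```

In the second case the present term contributes $(0.9 - 0.5)^2 = 0.16$ and the absent term $0.5 \times (0.3 - 0.1)^2 = 0.02$.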
Reconstruction Decoder
The decoder regularizes training by forcing the 16D capsule vectors to retain enough information to reconstruct the input.
    self.decoder = nn.Sequential(
        nn.Linear(16, 512),
        nn.ReLU(inplace=True),
        nn.Linear(512, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 784),
        nn.Sigmoid()
    )
During training, only the correct class capsule feeds into the decoder (others masked to zero). At inference, the longest capsule is used.
Reconstruction loss is MSE, weighted by $\lambda_{\text{recon}} = 0.0005$:
$$L = L_{\text{margin}} + \lambda_{\text{recon}} \cdot \text{MSE}(\hat{\mathbf{x}}, \mathbf{x})$$
The small weight keeps reconstruction as a regularizer without drowning out the margin loss.
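Combining the two terms might look like the sketch below. `capsnet_loss` is a hypothetical function name, `F.one_hot` stands in for the `scatter_` call used earlier, and the MSE is assumed to use PyTorch's default mean reduction over pixels and batch; the post does not state which reduction it uses.

```python
import torch
import torch.nn.functional as F


def capsnet_loss(v, recon, x, labels, lambda_recon=0.0005):
    """Margin loss plus down-weighted reconstruction loss."""
    v_length = torch.sqrt((v ** 2).sum(dim=-1) + 1e-8)
    T = F.one_hot(labels, num_classes=v.size(1)).float()
    present = T * F.relu(0.9 - v_length) ** 2
    absent = 0.5 * (1 - T) * F.relu(v_length - 0.1) ** 2
    margin = (present + absent).sum(dim=1).mean()
    # Assumption: default 'mean' reduction over all pixels and samples
    recon_loss = F.mse_loss(recon, x.reshape(x.size(0), -1))
    return margin + lambda_recon * recon_loss, margin, recon_loss


# Random stand-ins for model outputs, just to exercise the shapes
torch.manual_seed(0)
v = torch.rand(4, 10, 16) * 0.1
recon = torch.rand(4, 784)
x = torch.rand(4, 1, 28, 28)
labels = torch.tensor([3, 1, 4, 1])

total, margin, recon_loss = capsnet_loss(v, recon, x, labels)
print(total.item(), margin.item(), recon_loss.item())
```

Returning the two components alongside the total makes it easy to log them separately and confirm that the margin term dominates.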
Putting It All Together
    def forward(self, x, y=None):
        out = F.relu(self.conv1(x))       # (B, 256, 20, 20)
        primary = self.primary_caps(out)  # (B, 1152, 8)
        v = self.digit_caps(primary)      # (B, 10, 16)
        probs = torch.sqrt((v**2).sum(dim=-1) + 1e-8)
        # Mask: keep only correct/predicted capsule
        if y is not None:
            mask = torch.zeros_like(probs)
            mask.scatter_(1, y.unsqueeze(1), 1.0)
        else:
            _, idx = probs.max(dim=1)
            mask = torch.zeros_like(probs)
            mask.scatter_(1, idx.unsqueeze(1), 1.0)
        masked_v = (v * mask.unsqueeze(-1)).sum(dim=1)
        recon = self.decoder(masked_v)
        return v, recon
Training Configuration
- Dataset: MNIST, 5,000-image training subset.
- Optimizer: Adam, lr = 0.001.
- Batch size: 64.
- Epochs: 20.
- Routing iterations: 3.
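These settings imply a fairly short training run. The sketch below collects them in a hypothetical `TrainConfig` dataclass and counts optimizer steps, assuming the final partial batch is kept rather than dropped.

```python
import math
from dataclasses import dataclass


@dataclass
class TrainConfig:
    train_size: int = 5000   # MNIST training subset
    batch_size: int = 64
    epochs: int = 20
    lr: float = 1e-3         # Adam learning rate
    num_routing: int = 3


cfg = TrainConfig()
# Assumption: the last, smaller batch (8 images) is not dropped
steps_per_epoch = math.ceil(cfg.train_size / cfg.batch_size)
total_steps = steps_per_epoch * cfg.epochs

print(steps_per_epoch, total_steps)  # 79 1580
```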
Next
Part 3 tests the trained CapsNet: standard MNIST accuracy, reconstruction quality from 16D vectors, and rotation robustness compared to a simple CNN baseline.