Part 1 covered why residual learning works -- learning deviations from identity gives gradients a clear path through deep networks. Now we translate that math into PyTorch, building from a single residual block up to the full ResNet-18/34/50/101/152 family and a smaller variant for CIFAR-10.
The Residual Block
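The ResidualBlock implements $y = F(x) + x$: two $3 \times 3$ convolutions on the main path, with a skip connection that bypasses both. One detail matters here -- the final ReLU comes after the addition, not before. If you applied ReLU to the residual branch before adding, $F(x)$ would be forced to be non-negative, so the block could only push features up and never pull them down. Keeping the branch linear until after the addition lets the block learn corrections in either direction around the identity.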
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
    expansion = 1  # Output channels = out_channels * expansion
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample  # Projection shortcut, or None for identity
    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))  # 3x3 conv -> BN -> ReLU
        out = self.bn2(self.conv2(out))        # 3x3 conv -> BN (no ReLU yet)
        if self.downsample is not None:
            identity = self.downsample(x)      # Reshape x to match the main path
        out += identity                        # Skip connection!
        out = F.relu(out)                      # Final ReLU after addition
        return out
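To see why the ReLU placement matters, here is a small tensor-only experiment with no trained network involved. Random values stand in for the block input $x$ (non-negative, as it would be after a previous ReLU) and for the residual $F(x)$ (the output of bn2, which can have either sign). The variable names are illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins: x is a post-ReLU block input (>= 0); residual plays the role of F(x),
# i.e. the output of bn2, which can be positive or negative.
x = F.relu(torch.randn(10_000))
residual = torch.randn(10_000)

post_add = F.relu(x + residual)         # ReLU after the addition (as in ResidualBlock)
pre_add = F.relu(x + F.relu(residual))  # ReLU on F(x) before the addition

# With ReLU before the addition, F(x) >= 0, so no feature can ever shrink.
can_shrink_pre = bool((pre_add < x).any())
# With ReLU after the addition, negative residuals can reduce features.
can_shrink_post = bool((post_add < x).any())

print("pre-add ReLU can reduce a feature:", can_shrink_pre)    # False
print("post-add ReLU can reduce a feature:", can_shrink_post)  # True
```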
Handling Dimension Changes
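When the spatial size (a stride-2 block) or the channel count changes between stages, $x$ and $F(x)$ no longer have the same shape, so the identity shortcut cannot be added directly. Instead, the shortcut uses a $1 \times 1$ convolution with the same stride, followed by batch norm, to project $x$ to the matching shape: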
# Inside SmallResNet._make_layer: create a projection for the skip
# connection only when the shape changes; otherwise keep the identity
downsample = None
if stride != 1 or self.in_channels != out_channels * block.expansion:
    downsample = nn.Sequential(
        nn.Conv2d(self.in_channels, out_channels * block.expansion,
                  kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(out_channels * block.expansion),
    )
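A quick shape check makes the rule concrete. This sketch uses made-up sizes (a stage boundary going from 64 to 128 channels with stride 2) and standalone layers rather than the full block:

```python
import torch
import torch.nn as nn

# Illustrative sizes: first block of a stage that doubles channels (64 -> 128)
# and halves resolution (stride 2).
in_channels, out_channels, stride = 64, 128, 2
x = torch.randn(8, in_channels, 32, 32)

# Main path: the block's first 3x3 conv carries the stride
main = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                 stride=stride, padding=1, bias=False)
# Projection shortcut: 1x1 conv with the same stride, then BN
shortcut = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
    nn.BatchNorm2d(out_channels),
)

f_x = main(x)
print(tuple(x.shape), tuple(f_x.shape))  # (8, 64, 32, 32) (8, 128, 16, 16): can't add x directly
projected = shortcut(x)
print(tuple(projected.shape))            # (8, 128, 16, 16): shapes now match
y = f_x + projected
```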
Bottleneck Blocks
For ResNet-50 and deeper, running $3 \times 3$ convolutions at high channel counts gets expensive. Bottleneck blocks use a $1 \times 1 \rightarrow 3 \times 3 \rightarrow 1 \times 1$ pattern to keep computation manageable:
- $1 \times 1$ conv: Reduce channels (e.g., 256 → 64)
- $3 \times 3$ conv: Do the actual spatial processing at reduced width
- $1 \times 1$ conv: Expand back (64 → 256)
The expansion factor is 4, so output channels are always 4x the bottleneck width.
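Counting weights shows the savings. The sketch below compares a single $3 \times 3$ convolution at 256 channels with the three bottleneck layers using the 256 → 64 → 256 widths from the list above; BatchNorm parameters are ignored. Because every layer here runs at the same spatial resolution, the same ratio holds for multiply-adds per output position.

```python
import torch.nn as nn

def n_weights(conv):
    """Number of weights in a conv layer (all layers here use bias=False)."""
    return conv.weight.numel()

# One full-width 3x3 conv at 256 channels
plain = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)

# Bottleneck: 256 -> 64 (1x1), 64 -> 64 (3x3), 64 -> 256 (1x1)
squeeze = nn.Conv2d(256, 64, kernel_size=1, bias=False)
process = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
expand = nn.Conv2d(64, 256, kernel_size=1, bias=False)

plain_w = n_weights(plain)
bottleneck_w = sum(n_weights(c) for c in (squeeze, process, expand))
print(plain_w, bottleneck_w)              # 589824 69632
print(round(plain_w / bottleneck_w, 1))   # 8.5
```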
import torch.nn as nn
import torch.nn.functional as F
class BottleneckBlock(nn.Module):
    expansion = 4  # Output channels = out_channels (bottleneck width) * 4
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                               stride=1, bias=False)  # Squeeze
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)  # Process
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # Expand
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.downsample = downsample  # Projection shortcut, or None for identity
    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))    # 1x1 squeeze
        out = F.relu(self.bn2(self.conv2(out)))  # 3x3 process
        out = self.bn3(self.conv3(out))          # 1x1 expand (no ReLU)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return F.relu(out)                       # ReLU after the addition
The ResNet Family
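All five models share the same macro-architecture (a $7 \times 7$ stem with max pooling, four residual stages, global average pooling, and a linear classifier) and differ only in block type and the number of blocks in each of the four stages. Parameter counts are for the 1000-class ImageNet models:
ResNet-18/34 (Basic Blocks)
- ResNet-18: [2, 2, 2, 2] basic blocks, 11.7M params
- ResNet-34: [3, 4, 6, 3] basic blocks, 21.8M params
ResNet-50/101/152 (Bottleneck Blocks)
- ResNet-50: [3, 4, 6, 3] bottleneck blocks, 25.6M params
- ResNet-101: [3, 4, 23, 3] bottleneck blocks, 44.5M params
- ResNet-152: [3, 8, 36, 3] bottleneck blocks, 60.2M params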
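The parameter counts follow directly from the stage layout. The helper below is a back-of-the-envelope sketch (not torchvision's code; the function names are made up for this example) that tallies conv weights, BatchNorm scale and shift, projection shortcuts, and the final classifier. It assumes stage widths 64, 128, 256, 512, stride 2 at the start of stages 2 to 4, bias-free convolutions, and 1000 output classes.

```python
def conv_bn(cin, cout, k):
    """Weights of a bias-free k x k conv plus BatchNorm gamma and beta."""
    return cin * cout * k * k + 2 * cout


def resnet_params(block, layers, num_classes=1000):
    expansion = 1 if block == "basic" else 4
    total = conv_bn(3, 64, 7)                 # 7x7 stem (max pooling has no parameters)
    in_ch = 64
    for stage, num_blocks in enumerate(layers):
        width = 64 * 2 ** stage               # 64, 128, 256, 512
        out_ch = width * expansion
        for b in range(num_blocks):
            stride = 2 if (b == 0 and stage > 0) else 1
            if block == "basic":
                total += conv_bn(in_ch, width, 3) + conv_bn(width, width, 3)
            else:
                total += (conv_bn(in_ch, width, 1) + conv_bn(width, width, 3)
                          + conv_bn(width, out_ch, 1))
            if stride != 1 or in_ch != out_ch:
                total += conv_bn(in_ch, out_ch, 1)  # 1x1 projection shortcut
            in_ch = out_ch
    return total + in_ch * num_classes + num_classes  # linear classifier


configs = {
    "ResNet-18": ("basic", [2, 2, 2, 2]),
    "ResNet-34": ("basic", [3, 4, 6, 3]),
    "ResNet-50": ("bottleneck", [3, 4, 6, 3]),
    "ResNet-101": ("bottleneck", [3, 4, 23, 3]),
    "ResNet-152": ("bottleneck", [3, 8, 36, 3]),
}
for name, (block, layers) in configs.items():
    print(f"{name}: {resnet_params(block, layers) / 1e6:.1f}M")
```

The rounded totals reproduce the figures in the list. Most of the growth from ResNet-50 to ResNet-152 comes from stage 3, which goes from 6 to 36 blocks.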
SmallResNet for CIFAR-10
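The standard ResNet stem -- a $7 \times 7$ stride-2 conv followed by stride-2 max pooling -- is designed for $224 \times 224$ ImageNet images, which it reduces to $56 \times 56$. On $32 \times 32$ CIFAR-10 inputs, the same stem leaves only an $8 \times 8$ feature map before the first residual block, throwing away most of the spatial detail. Our SmallResNet makes four changes: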
- $3 \times 3$ conv stem with stride 1 (instead of $7 \times 7$ stride 2)
- No initial max pooling
- 3 residual stages instead of 4
- 16 initial channels instead of 64
This gives a 175,258-parameter model -- small enough to train on CPU in a few minutes.
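Tracing tensor shapes shows how quickly the ImageNet stem shrinks a CIFAR-10 image. The sketch below builds only the two stems with untrained layers (64 and 16 output channels, as in the text) and then computes the feature-map size after each residual stage with a small helper, stage_sizes, written for this example:

```python
import torch
import torch.nn as nn


def stage_sizes(size, strides):
    """Spatial size after each stage, assuming each stage's first 3x3 conv
    (padding 1) carries the stage stride."""
    sizes = []
    for s in strides:
        size = (size - 1) // s + 1
        sizes.append(size)
    return sizes


x = torch.randn(1, 3, 32, 32)  # one CIFAR-10-sized image

# ImageNet-style stem: 7x7 conv, stride 2, then 3x3 max pool, stride 2
imagenet_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
# SmallResNet-style stem: 3x3 conv, stride 1, no pooling
small_stem = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)

print(tuple(imagenet_stem(x).shape))  # (1, 64, 8, 8)
print(tuple(small_stem(x).shape))     # (1, 16, 32, 32)

print(stage_sizes(8, [1, 2, 2, 2]))   # [8, 4, 2, 1]   standard ResNet on CIFAR-10
print(stage_sizes(32, [1, 2, 2]))     # [32, 16, 8]    SmallResNet
```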
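import torch
import torch.nn as nn
import torch.nn.functional as F
class SmallResNet(nn.Module):
    # _make_layer (with the downsample logic above) and _initialize_weights are also methods here
    def __init__(self, block, blocks_per_layer, num_classes=10,
                 in_channels=3, initial_channels=16):
        super().__init__()
        self.in_channels = initial_channels
        # 3x3 stem (no maxpool for small images)
        self.conv1 = nn.Conv2d(in_channels, initial_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(initial_channels)
        # 3 residual stages (not 4)
        self.layer1 = self._make_layer(block, initial_channels, blocks_per_layer[0], stride=1)
        self.layer2 = self._make_layer(block, initial_channels * 2, blocks_per_layer[1], stride=2)
        self.layer3 = self._make_layer(block, initial_channels * 4, blocks_per_layer[2], stride=2)
        # Global average pooling + classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(initial_channels * 4 * block.expansion, num_classes)
        self._initialize_weights()  # He init for convs, constant init for BN
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))  # stem keeps 32x32 for CIFAR-10 input
        out = self.layer1(out)                 # 32x32
        out = self.layer2(out)                 # 16x16
        out = self.layer3(out)                 # 8x8
        out = self.avgpool(out)                # 1x1
        out = torch.flatten(out, 1)
        return self.fc(out)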
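As a sanity check on the 175,258 figure, the same kind of hand tally works for the CIFAR layout. This sketch rests on one assumption: that SmallResNet-18 means basic blocks with two blocks in each of the three stages ([2, 2, 2]), which is the configuration that reproduces the stated count. The helper name small_resnet_params is hypothetical.

```python
def small_resnet_params(blocks_per_layer, initial_channels=16,
                        num_classes=10, in_channels=3):
    """Parameter count for SmallResNet built from basic blocks (expansion = 1)."""
    def conv_bn(cin, cout, k):
        return cin * cout * k * k + 2 * cout      # bias-free conv + BN gamma/beta

    total = conv_bn(in_channels, initial_channels, 3)  # 3x3 stem, no max pool
    in_ch = initial_channels
    for stage, num_blocks in enumerate(blocks_per_layer):
        width = initial_channels * 2 ** stage          # 16, 32, 64
        for b in range(num_blocks):
            stride = 2 if (b == 0 and stage > 0) else 1
            total += conv_bn(in_ch, width, 3) + conv_bn(width, width, 3)
            if stride != 1 or in_ch != width:
                total += conv_bn(in_ch, width, 1)      # 1x1 projection shortcut
            in_ch = width
    return total + in_ch * num_classes + num_classes   # final linear layer


print(small_resnet_params([2, 2, 2]))  # 175258
```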
Weight Initialization
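We use He (Kaiming) initialization for all convolutions. Xavier initialization is derived for activations that are roughly linear around zero, but ReLU zeros out the negative half of its input, which halves the signal's second moment at every layer. He initialization compensates by drawing weights with variance $2/n$ instead of Xavier's $2/(n_{\text{in}} + n_{\text{out}})$; with mode='fan_out', $n$ is the number of output channels times the kernel area. Batch norm weights go to 1 and biases to 0, so every BN layer starts as a plain normalization: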
# Method of SmallResNet
def _initialize_weights(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            # He normal: std = sqrt(2 / fan_out), fan_out = out_channels * kH * kW
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
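The effect is easy to see numerically. The sketch below uses plain fully connected layers rather than convolutions (square layers, so fan-in equals fan-out and the mode setting does not matter), random Gaussian inputs, and an arbitrary depth of 20; final_second_moment is a hypothetical helper. Under He initialization the mean squared activation should stay roughly constant, while under Xavier (variance $1/n$ for a square layer) it should shrink by about half per layer.

```python
import torch
import torch.nn as nn


def final_second_moment(init_fn, depth=20, width=512, seed=0):
    """Push random inputs through `depth` bias-free Linear+ReLU layers and
    return the mean squared activation at the end."""
    torch.manual_seed(seed)
    x = torch.randn(1024, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)
            x = torch.relu(layer(x))
    return x.pow(2).mean().item()


he = final_second_moment(lambda w: nn.init.kaiming_normal_(w, nonlinearity="relu"))
xavier = final_second_moment(nn.init.xavier_normal_)
print(f"He:     {he:.3g}")      # stays on the order of 1
print(f"Xavier: {xavier:.3g}")  # roughly 2**-20 after 20 layers
```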
Training Strategy
We train SmallResNet-18 for 30 epochs on a 5,000-image subset of CIFAR-10 with the following setup:
- Data augmentation: Random horizontal flip + random crop with 4px padding
- Optimizer: SGD, momentum 0.9, weight decay $5 \times 10^{-4}$
- Learning rate: 0.1, decayed by 10x at epochs 15 and 25
- Batch size: 128
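The two augmentations are usually expressed with torchvision, as transforms.RandomCrop(32, padding=4) and transforms.RandomHorizontalFlip(). To show what they do, here is a dependency-free sketch on a single image tensor using only core PyTorch. The augment helper is hypothetical, and it assumes zero padding (RandomCrop's default fill) and a flip probability of 0.5.

```python
import torch
import torch.nn.functional as F


def augment(img, pad=4, generator=None):
    """img: (C, H, W) tensor. Zero-pad, random crop back to H x W, random h-flip."""
    c, h, w = img.shape
    padded = F.pad(img, (pad, pad, pad, pad))     # pads width and height with zeros
    top = torch.randint(0, 2 * pad + 1, (1,), generator=generator).item()
    left = torch.randint(0, 2 * pad + 1, (1,), generator=generator).item()
    out = padded[:, top:top + h, left:left + w]   # random H x W window
    if torch.rand(1, generator=generator).item() < 0.5:
        out = torch.flip(out, dims=[2])           # flip along the width axis
    return out


g = torch.Generator().manual_seed(0)
img = torch.rand(3, 32, 32)                       # stand-in for a CIFAR-10 image
aug = augment(img, generator=g)
print(aug.shape)  # torch.Size([3, 32, 32])
```

Each call draws a fresh crop offset and flip decision, so the network sees a slightly different version of every image each epoch.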
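import torch
import torch.nn as nn
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SmallResNet18(num_classes=10, in_channels=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.1)  # step once per epoch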
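Because MultiStepLR counts calls to scheduler.step(), it has to be stepped once per epoch (not per batch) for the milestones to mean epochs. A quick way to confirm the schedule without training is to step it against a single dummy parameter standing in for the model:

```python
import torch
import torch.optim as optim

# A dummy parameter stands in for the model; only the schedule matters here.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.1)

lrs = []
for epoch in range(30):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used during this epoch
    optimizer.step()    # the real per-batch training loop would run here
    scheduler.step()    # advance the schedule once per epoch

print(lrs[0], lrs[14], lrs[15], lrs[24], lrs[25], lrs[29])
```

Epochs 0 to 14 run at 0.1, epochs 15 to 24 at 0.01, and epochs 25 to 29 at 0.001.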
Next: Training Results
The implementation is done. Part 3 trains this model, plots the loss and accuracy curves, and visualizes activation magnitudes layer-by-layer to confirm that skip connections actually preserve signal flow through the network.