Part 1 covered the math: images as patch sequences, the CLS token, and learnable positional embeddings. Now we turn every equation into working PyTorch code.
Four modules, built from scratch: PatchEmbedding, MultiHeadSelfAttention, TransformerBlock, and VisionTransformer. We implement attention manually -- no nn.MultiheadAttention -- so every matrix multiplication and softmax is visible.
PatchEmbedding: The Conv2d Trick
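import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, embed_dim=128):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Conv2d trick: kernel_size=stride=patch_size
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, embed_dim)
        )
        self.position_embeddings = nn.Parameter(
            torch.randn(1, 1 + self.num_patches, embed_dim)
        )

    def forward(self, x):
        # Follows the shape trace listed below
        x = self.projection(x)                 # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)       # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)         # (B, 1 + N, D)
        return x + self.position_embeddings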
Tensor shapes through the pipeline for CIFAR-10 ($32 \times 32$, $P=4$, $D=128$):
- Input: $(B, 3, 32, 32)$
- After Conv2d: $(B, 128, 8, 8)$ -- 64 patches, each projected to 128 dims
- After flatten + transpose: $(B, 64, 128)$
- After CLS prepend: $(B, 65, 128)$
- After position addition: $(B, 65, 128)$
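To see why the Conv2d call really is a patch projection, here is a minimal sketch comparing it against the explicit route: cut out each $4 \times 4$ patch, flatten it to 48 values, and apply a single linear map. The variable names (`tokens_conv`, `tokens_linear`) are ours, chosen for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, P, D = 2, 3, 32, 4, 128

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
x = torch.randn(B, C, H, H)

# Path 1: the Conv2d trick, then flatten the 8x8 grid into a sequence.
tokens_conv = conv(x).flatten(2).transpose(1, 2)            # (B, 64, 128)

# Path 2: cut out patches explicitly, flatten each, apply one linear map.
patches = x.unfold(2, P, P).unfold(3, P, P)                 # (B, C, 8, 8, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5)                 # (B, 8, 8, C, P, P)
patches = patches.reshape(B, -1, C * P * P)                 # (B, 64, 48)
weight = conv.weight.reshape(D, C * P * P)                  # (128, 48)
tokens_linear = patches @ weight.T + conv.bias              # (B, 64, 128)

max_diff = (tokens_conv - tokens_linear).abs().max().item()
print(tokens_conv.shape, max_diff)
```

The two paths should agree up to floating-point rounding, which is the sense in which the convolution is just a shared linear layer applied to every patch.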
MultiHeadSelfAttention: From Scratch
The heart of the Transformer. Scaled dot-product attention with multiple heads, built entirely from linear projections and matrix multiplications:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 32
        self.scale = self.head_dim ** -0.5
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
The forward pass implements the attention equation directly:
    def forward(self, x):
        B, N, C = x.shape
        # Project and reshape for multi-head
        Q = self.W_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.W_k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.W_v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Scaled dot-product attention
        attn = (Q @ K.transpose(-2, -1)) * self.scale  # (B, h, N, N)
        attn = F.softmax(attn, dim=-1)
        out = attn @ V  # (B, h, N, d_k)
        # Concatenate heads and project
        out = out.permute(0, 2, 1, 3).reshape(B, N, C)
        return self.out_proj(out), attn
We return the attention weights $\text{attn} \in \mathbb{R}^{B \times h \times N \times N}$ alongside the output -- useful for visualizing what the model looks at.
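As a rough illustration of what those returned weights look like, the sketch below builds an attention tensor from random stand-in queries and keys (not a trained model), checks that every row is a probability distribution, and folds the CLS row back into the $8 \times 8$ patch grid. Averaging over heads is one possible choice for display, not the only one.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, N, d_k = 2, 4, 65, 32          # 1 CLS token + 64 patches

# Stand-ins for the per-head Q and K tensors (random, for illustration only).
Q = torch.randn(B, h, N, d_k)
K = torch.randn(B, h, N, d_k)

attn = F.softmax((Q @ K.transpose(-2, -1)) * d_k ** -0.5, dim=-1)  # (B, h, N, N)

# Each row is a probability distribution over the N keys.
row_sums = attn.sum(dim=-1)                                        # (B, h, N)

# Row 0 is the CLS query; drop its attention to itself and fold the
# remaining 64 patch weights back into the 8x8 patch grid.
cls_to_patches = attn[:, :, 0, 1:]                                 # (B, h, 64)
cls_map = cls_to_patches.reshape(B, h, 8, 8).mean(dim=1)           # (B, 8, 8)
print(cls_map.shape)
```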
Why the $\sqrt{d_k}$ Scaling?
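Without scaling, the dot products $QK^\top$ grow in magnitude with $d_k$: if the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$. Large dot products push the softmax toward near-one-hot distributions, where its gradients become vanishingly small. Dividing by $\sqrt{d_k}$ brings the variance of the dot products back to about 1.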
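A quick empirical check of this claim, assuming query and key entries drawn independently from a standard normal (an idealization; trained projections will not match it exactly). The helper `dot_product_variance` is a hypothetical name introduced here.

```python
import torch

def dot_product_variance(d_k, n_samples=20_000, seed=0):
    """Empirical variance of q.k before and after dividing by sqrt(d_k)."""
    gen = torch.Generator().manual_seed(seed)
    q = torch.randn(n_samples, d_k, generator=gen)
    k = torch.randn(n_samples, d_k, generator=gen)
    dots = (q * k).sum(dim=-1)               # one dot product per sample
    return dots.var().item(), (dots / d_k ** 0.5).var().item()

results = {d_k: dot_product_variance(d_k) for d_k in (8, 32, 128)}
for d_k, (raw, scaled) in results.items():
    print(f"d_k={d_k:4d}  var(q.k)={raw:8.2f}  var(q.k / sqrt(d_k))={scaled:.3f}")
```

The unscaled variance should track $d_k$, while the scaled one should stay close to 1 for every head dimension.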
TransformerBlock: Pre-Norm Architecture
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),  # 128 -> 512
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),  # 512 -> 128
        )

    def forward(self, x):
        # Pre-norm MHSA + residual
        attn_out, attn_w = self.attn(self.norm1(x))
        x = x + attn_out
        # Pre-norm MLP + residual
        x = x + self.mlp(self.norm2(x))
        return x, attn_w
Pre-Norm vs Post-Norm
ViT applies LayerNorm before each sub-layer (pre-norm), not after. This stabilizes training for deeper models because the residual path stays unnormalized -- gradients flow through the identity connection without being rescaled.
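The difference between the two orderings fits in two one-line functions. This is a sketch with a plain linear layer standing in for attention or the MLP; `pre_norm_step` and `post_norm_step` are illustrative names, not part of the model code.

```python
import torch
import torch.nn as nn

def pre_norm_step(x, sublayer, norm):
    # Pre-norm: normalize the sub-layer input; the skip path carries x untouched.
    return x + sublayer(norm(x))

def post_norm_step(x, sublayer, norm):
    # Post-norm (original Transformer): normalize after the residual sum,
    # so the skip path itself passes through LayerNorm.
    return norm(x + sublayer(x))

torch.manual_seed(0)
D = 128
norm = nn.LayerNorm(D)
sublayer = nn.Linear(D, D)        # stand-in for attention or the MLP
x = 5.0 * torch.randn(2, 65, D)   # deliberately not unit-scale

pre = pre_norm_step(x, sublayer, norm)
post = post_norm_step(x, sublayer, norm)
print(f"input std {x.std().item():.2f}, pre-norm std {pre.std().item():.2f}, "
      f"post-norm std {post.std().item():.2f}")
```

With pre-norm the output keeps the scale of the input because the identity path is never rescaled. With post-norm every token is renormalized after the sum.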
GELU Activation
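The MLP uses GELU instead of ReLU:
$$\text{GELU}(x) = x \, \Phi(x) = \frac{x}{2}\left[1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$
where $\Phi$ is the standard normal CDF.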
GELU smoothly approximates ReLU, avoiding the hard zero-gradient cutoff. It has become the default activation in Transformer architectures.
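For a concrete check, the erf form of GELU can be written by hand and compared with PyTorch's built-in. This is a small sketch; `gelu_exact` is our own helper name.

```python
import math
import torch
import torch.nn.functional as F

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF written via erf.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.linspace(-4.0, 4.0, steps=81)
manual = gelu_exact(x)
builtin = F.gelu(x)               # the default is the exact erf form
max_diff = (manual - builtin).abs().max().item()

# Unlike ReLU, GELU has a nonzero gradient for moderately negative inputs.
x0 = torch.tensor(-1.0, requires_grad=True)
gelu_exact(x0).backward()
grad_at_minus_one = x0.grad.item()
print(max_diff, grad_at_minus_one)
```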
The Residual Connection
Same principle as in ResNets. The residual connection $x + F(x)$ guarantees gradient flow:
$$\frac{\partial}{\partial x}\bigl(x + F(x)\bigr) = \mathbf{I} + \frac{\partial F}{\partial x}$$
The identity term $\mathbf{I}$ means gradients always have a direct path, no matter the depth.
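Autograd can confirm the identity term numerically. The sketch below uses a tiny linear-plus-tanh layer as a stand-in for $F$ (an arbitrary choice for illustration) and compares the Jacobian of $x + F(x)$ with $\mathbf{I} + \partial F / \partial x$.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)
D = 8                                                  # small, easy to inspect
F_layer = nn.Sequential(nn.Linear(D, D), nn.Tanh())    # stand-in for a sub-layer F

x = torch.randn(D)
J_F = jacobian(lambda v: F_layer(v), x)                # dF/dx, shape (D, D)
J_res = jacobian(lambda v: v + F_layer(v), x)          # d(x + F(x))/dx

identity_gap = (J_res - (torch.eye(D) + J_F)).abs().max().item()
print(identity_gap)
```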
VisionTransformer: Full Assembly
class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, num_classes=10,
                 embed_dim=128, num_heads=4,
                 depth=6, mlp_ratio=4.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(...)
        self.blocks = nn.ModuleList([
            TransformerBlock(...) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)  # (B, 65, 128)
        for block in self.blocks:
            x, _ = block(x)  # 6x Transformer
        x = self.norm(x)
        cls = x[:, 0]  # CLS token
        return self.head(cls)  # (B, 10)
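The listing above leaves the constructor arguments elided as `(...)`. Here is one compact, runnable way the four modules might fit together. It assumes each elided call simply forwards the matching constructor arguments, and that `PatchEmbedding.forward` follows the shape trace from earlier (flatten, transpose, prepend CLS, add positions). Treat it as a sketch under those assumptions rather than the definitive listing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(in_channels, embed_dim,
                                    kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.position_embeddings = nn.Parameter(
            torch.randn(1, 1 + self.num_patches, embed_dim))

    def forward(self, x):
        x = self.projection(x).flatten(2).transpose(1, 2)    # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)      # (B, 1, D)
        return torch.cat([cls, x], dim=1) + self.position_embeddings


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, B, N):
        return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x):
        B, N, C = x.shape
        Q = self._split_heads(self.W_q(x), B, N)
        K = self._split_heads(self.W_k(x), B, N)
        V = self._split_heads(self.W_v(x), B, N)
        attn = F.softmax((Q @ K.transpose(-2, -1)) * self.scale, dim=-1)
        out = (attn @ V).permute(0, 2, 1, 3).reshape(B, N, C)
        return self.out_proj(out), attn


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_dim), nn.GELU(),
                                 nn.Linear(mlp_dim, embed_dim))

    def forward(self, x):
        attn_out, attn_w = self.attn(self.norm1(x))
        x = x + attn_out
        return x + self.mlp(self.norm2(x)), attn_w


class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10,
                 embed_dim=128, num_heads=4, depth=6, mlp_ratio=4.0):
        super().__init__()
        # Assumption: the elided "(...)" calls forward the matching arguments.
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        for block in self.blocks:
            x, _ = block(x)
        return self.head(self.norm(x)[:, 0])     # classify from the CLS token


model = VisionTransformer()
logits = model(torch.randn(2, 3, 32, 32))
num_params = sum(p.numel() for p in model.parameters())
print(logits.shape, f"{num_params:,}")
```

Under these assumptions the parameter total comes out to the same 1,205,898 quoted in the summary below, which is a useful sanity check that nothing was wired up differently.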
Architecture Summary
Our ViT-Tiny for CIFAR-10:
- Image: $32 \times 32 \times 3$
- Patch size: $4 \times 4$ $\rightarrow$ 64 patches
- Embedding dimension: $D = 128$
- Heads: $h = 4$ (each with $d_k = 32$)
- Depth: $L = 6$ Transformer blocks
- MLP ratio: $4\times$ (hidden dim $= 512$)
- Total parameters: 1,205,898
Parameter Count Breakdown
Where the 1.2M parameters live:
- Patch Embedding Conv2d: $3 \times 128 \times 4 \times 4 + 128 = 6{,}272$
- CLS token: $128$
- Positional embeddings: $65 \times 128 = 8{,}320$
- Per Transformer block: Q, K, V projections ($49{,}536$) + Output projection ($16{,}512$) + MLP ($131{,}712$) + LayerNorms ($512$) $= 198{,}272$
- 6 blocks: $6 \times 198{,}272 = 1{,}189{,}632$
- Final LayerNorm: $2 \times 128 = 256$
- Classification head: $128 \times 10 + 10 = 1{,}290$
Over 98% of the parameters sit in the Transformer blocks. Within each block, the MLP holds about two thirds of the weights ($131{,}712$) and the attention projections about one third ($66{,}048$).
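The arithmetic is easy to reproduce without PyTorch. The sketch below recomputes each line of the breakdown, including the final LayerNorm before the head (`self.norm` in `VisionTransformer`), and checks that the pieces sum to the quoted total. The helper `linear` and the dictionary keys are names made up for this example.

```python
D, P, C_in, N, depth, mlp_dim, classes = 128, 4, 3, 65, 6, 512, 10

def linear(n_in, n_out):
    return n_in * n_out + n_out           # weight + bias

attention = 3 * linear(D, D) + linear(D, D)         # Q, K, V + output projection
mlp = linear(D, mlp_dim) + linear(mlp_dim, D)
layer_norms = 2 * 2 * D                             # two LayerNorms, scale + shift
per_block = attention + mlp + layer_norms

counts = {
    "patch_conv": C_in * D * P * P + D,
    "cls_token": D,
    "pos_embed": N * D,
    "blocks": depth * per_block,
    "final_norm": 2 * D,
    "head": linear(D, classes),
}
total = sum(counts.values())
block_share = counts["blocks"] / total
mlp_share = mlp / per_block

for name, n in counts.items():
    print(f"{name:12s} {n:>10,}")
print(f"{'total':12s} {total:>10,}  (blocks: {block_share:.1%}, MLP share of a block: {mlp_share:.1%})")
```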
Next: Training and Analysis
That gives us a complete Vision Transformer in pure PyTorch. In Part 3, we train it on CIFAR-10 alongside a CNN baseline and dig into training dynamics, attention maps, positional embeddings, and the data efficiency gap.