In Part 1, we established the mathematical foundations of Vision Transformers: images as patch sequences, the CLS token as a global aggregator, and learnable positional embeddings for spatial awareness. Now we translate every equation into working PyTorch code.
We build four modules from scratch: PatchEmbedding, MultiHeadSelfAttention, TransformerBlock, and VisionTransformer. Crucially, we implement attention manually -- no nn.MultiheadAttention -- to expose every matrix multiplication and softmax operation.
PatchEmbedding: The Conv2d Trick
```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, embed_dim=128):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Conv2d trick: kernel_size=stride=patch_size
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, embed_dim)
        )
        self.position_embeddings = nn.Parameter(
            torch.randn(1, 1 + self.num_patches, embed_dim)
        )

    def forward(self, x):
        B = x.shape[0]
        x = self.projection(x)             # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)   # (B, num_patches, D)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)     # prepend CLS token
        return x + self.position_embeddings
```
For CIFAR-10 ($32 \times 32$, $P=4$, $D=128$):
- Input: $(B, 3, 32, 32)$
- After Conv2d: $(B, 128, 8, 8)$ -- 64 patches, each projected to 128 dims
- After flatten + transpose: $(B, 64, 128)$
- After CLS prepend: $(B, 65, 128)$
- After position addition: $(B, 65, 128)$
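The shape walkthrough above can be verified directly with a bare `nn.Conv2d` (a minimal sketch, assuming the same CIFAR-10 settings: `img_size=32`, `patch_size=4`, `embed_dim=128`):

```python
import torch
import torch.nn as nn

# Conv2d trick: kernel_size == stride == patch_size, so each 4x4 patch
# is projected to a 128-dim vector with no overlap.
proj = nn.Conv2d(3, 128, kernel_size=4, stride=4)

x = torch.randn(2, 3, 32, 32)              # (B, 3, 32, 32)
feat = proj(x)                             # (B, 128, 8, 8)
patches = feat.flatten(2).transpose(1, 2)  # (B, 64, 128)

cls_token = torch.randn(1, 1, 128).expand(2, -1, -1)
tokens = torch.cat([cls_token, patches], dim=1)  # (B, 65, 128)

print(feat.shape, patches.shape, tokens.shape)
```

The `flatten(2).transpose(1, 2)` step is what turns the convolutional feature map into a token sequence.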
MultiHeadSelfAttention: From Scratch
This is the core of the Transformer. We implement scaled dot-product attention with multiple heads, entirely from linear projections and matrix multiplications:
```python
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 32
        self.scale = self.head_dim ** -0.5
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
```
The forward pass follows the attention formula precisely:
```python
    def forward(self, x):
        B, N, C = x.shape
        # Project and reshape for multi-head
        Q = self.W_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = self.W_k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = self.W_v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # Scaled dot-product attention
        attn = (Q @ K.transpose(-2, -1)) * self.scale  # (B, h, N, N)
        attn = F.softmax(attn, dim=-1)
        out = attn @ V  # (B, h, N, d_k)
        # Concatenate heads and project
        out = out.permute(0, 2, 1, 3).reshape(B, N, C)
        return self.out_proj(out), attn
```
The attention weights $\text{attn} \in \mathbb{R}^{B \times h \times N \times N}$ tell us exactly how much each token attends to every other token. We return these for visualization.
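A quick sanity check on these weights, as a standalone sketch on toy tensors (the shapes match our model: $B=2$, $h=4$, $N=65$, $d_k=32$): after the softmax, every row of the attention matrix is a probability distribution over the $N$ tokens.

```python
import torch
import torch.nn.functional as F

B, h, N, d_k = 2, 4, 65, 32
Q, K, V = (torch.randn(B, h, N, d_k) for _ in range(3))

# Scaled dot-product attention, exactly as in the forward pass above
attn = F.softmax((Q @ K.transpose(-2, -1)) * d_k ** -0.5, dim=-1)  # (B, h, N, N)
out = attn @ V                                                     # (B, h, N, d_k)

row_sums = attn.sum(dim=-1)  # each row sums to 1
```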
Why the $\sqrt{d_k}$ Scaling?
Without scaling, the dot products $QK^\top$ grow in magnitude with dimension $d_k$. For large $d_k$, the softmax saturates into near-one-hot distributions, producing vanishingly small gradients. Dividing by $\sqrt{d_k}$ keeps the variance of the dot products at approximately 1.
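This is easy to confirm numerically. A small sketch, assuming i.i.d. standard-normal query and key entries: the variance of an unscaled dot product grows to roughly $d_k$, while dividing by $\sqrt{d_k}$ brings it back to roughly 1.

```python
import torch

torch.manual_seed(0)
d_k = 256
q = torch.randn(10000, d_k)  # 10,000 sample query vectors
k = torch.randn(10000, d_k)  # 10,000 sample key vectors

raw = (q * k).sum(dim=-1)    # unscaled dot products: Var ~ d_k
scaled = raw * d_k ** -0.5   # scaled dot products: Var ~ 1

print(raw.var().item(), scaled.var().item())
```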
TransformerBlock: Pre-Norm Architecture
```python
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),  # 128 -> 512
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),  # 512 -> 128
        )

    def forward(self, x):
        # Pre-norm MHSA + residual
        attn_out, attn_w = self.attn(self.norm1(x))
        x = x + attn_out
        # Pre-norm MLP + residual
        x = x + self.mlp(self.norm2(x))
        return x, attn_w
```
Pre-Norm vs Post-Norm
ViT uses pre-norm: LayerNorm is applied before each sub-layer. Pre-norm stabilizes training, especially for deeper models, because the residual path remains unnormalized -- gradients flow through the identity connection without being scaled by normalization.
GELU Activation
The MLP uses GELU (Gaussian Error Linear Unit) rather than ReLU:

$$\text{GELU}(x) = x \cdot \Phi(x)$$

where $\Phi$ is the standard normal CDF. GELU provides a smooth approximation to ReLU, avoiding the sharp zero-gradient region, and has become standard in Transformer architectures.
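The difference is visible in the gradients. A small sketch: at a moderately negative input, ReLU's gradient is exactly zero, while GELU still passes gradient through.

```python
import torch
import torch.nn.functional as F

# Gradient of GELU at x = -0.5: nonzero (smooth activation)
x = torch.tensor([-0.5], requires_grad=True)
F.gelu(x).sum().backward()
gelu_grad = x.grad.item()

# Gradient of ReLU at x = -0.5: exactly zero (dead region)
y = torch.tensor([-0.5], requires_grad=True)
F.relu(y).sum().backward()
relu_grad = y.grad.item()

print(gelu_grad, relu_grad)
```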
The Residual Connection
The residual connection $x + F(x)$ is the same principle from our ResNet series. The mathematical guarantee is identical:

$$\frac{\partial\, (x + F(x))}{\partial x} = \mathbf{I} + \frac{\partial F}{\partial x}$$

The identity term $\mathbf{I}$ ensures gradients always flow, regardless of how many layers the network has.
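Autograd confirms this identity-plus-Jacobian structure directly. A minimal sketch, using a random linear map as a stand-in for $F$ (an assumption for illustration only):

```python
import torch

torch.manual_seed(0)
W = torch.randn(4, 4) * 0.1  # stand-in for F's Jacobian

def residual(x):
    return x + x @ W.T  # x + F(x), with F(x) = Wx

x = torch.randn(4)
J = torch.autograd.functional.jacobian(residual, x)
# The Jacobian of the residual branch is I + W: the identity term survives
# no matter how small W's contribution is.
```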
VisionTransformer: Full Assembly
```python
class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, num_classes=10,
                 embed_dim=128, num_heads=4,
                 depth=6, mlp_ratio=4.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)  # (B, 65, 128)
        for block in self.blocks:
            x, _ = block(x)      # 6x Transformer blocks
        x = self.norm(x)
        cls = x[:, 0]            # CLS token
        return self.head(cls)    # (B, 10)
```
Architecture Summary
Our ViT-Tiny for CIFAR-10:
- Image: $32 \times 32 \times 3$
- Patch size: $4 \times 4$ $\rightarrow$ 64 patches
- Embedding dimension: $D = 128$
- Heads: $h = 4$ (each with $d_k = 32$)
- Depth: $L = 6$ Transformer blocks
- MLP ratio: $4\times$ (hidden dim $= 512$)
- Total parameters: 1,205,898
Parameter Count Breakdown
Where do the 1.2M parameters come from?
- Patch Embedding Conv2d: $3 \times 128 \times 4 \times 4 + 128 = 6{,}272$
- CLS token: $128$
- Positional embeddings: $65 \times 128 = 8{,}320$
- Per Transformer block: Q, K, V projections ($49{,}536$) + output projection ($16{,}512$) + MLP ($131{,}712$) + two LayerNorms ($512$) $= 198{,}272$
- 6 blocks: $6 \times 198{,}272 = 1{,}189{,}632$
- Final LayerNorm: $2 \times 128 = 256$
- Classification head: $128 \times 10 + 10 = 1{,}290$
The vast majority of parameters ($>$98%) are in the Transformer blocks. Within each block, the MLP ($131{,}712$) holds roughly twice as many parameters as the attention projections ($66{,}048$), a consequence of the $4\times$ MLP expansion.
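The breakdown can be re-derived with a few lines of arithmetic, using the hyperparameters from the architecture summary:

```python
embed_dim, mlp_dim, num_patches, depth, num_classes = 128, 512, 64, 6, 10

patch_conv = 3 * embed_dim * 4 * 4 + embed_dim              # Conv2d weights + bias
cls_token = embed_dim                                       # learnable CLS token
pos_embed = (1 + num_patches) * embed_dim                   # positional embeddings

qkv = 3 * (embed_dim * embed_dim + embed_dim)               # Q, K, V projections
out_proj = embed_dim * embed_dim + embed_dim                # output projection
mlp = (embed_dim * mlp_dim + mlp_dim) + (mlp_dim * embed_dim + embed_dim)
norms = 2 * 2 * embed_dim                                   # two LayerNorms per block
per_block = qkv + out_proj + mlp + norms

final_norm = 2 * embed_dim                                  # LayerNorm after the blocks
head = embed_dim * num_classes + num_classes                # classification head

total = patch_conv + cls_token + pos_embed + depth * per_block + final_norm + head
print(per_block, total)  # 198272, 1205898
```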
Next Steps: Training and Analysis
We have a complete, working Vision Transformer built entirely from scratch in PyTorch. In Part 3, we train it on CIFAR-10 alongside a CNN baseline and analyze training dynamics, attention maps, positional embeddings, and the data efficiency question.