
Deconstructing ViTs from Scratch

Part 2: PyTorch Implementation

Introduction

Part 1 covered the math: images as patch sequences, the CLS token, and learnable positional embeddings. Now we turn every equation into working PyTorch code.

We build four modules from scratch: PatchEmbedding, MultiHeadSelfAttention, TransformerBlock, and VisionTransformer. We implement attention manually instead of using nn.MultiheadAttention, so every matrix multiplication and softmax is visible.

PatchEmbedding: The Conv2d Trick

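import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, embed_dim=128):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Conv2d trick: kernel_size=stride=patch_size
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, embed_dim)
        )
        self.position_embeddings = nn.Parameter(
            torch.randn(1, 1 + self.num_patches, embed_dim)
        )

    def forward(self, x):
        B = x.shape[0]
        x = self.projection(x)                  # (B, D, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)        # (B, N, D): one token per patch
        cls = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat([cls, x], dim=1)          # (B, N+1, D): CLS first
        return x + self.position_embeddings     # (B, N+1, D)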

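Tensor shapes through the pipeline for CIFAR-10 ($32 \times 32$, $P=4$, $D=128$):

Input image: (B, 3, 32, 32)
Conv2d projection: (B, 128, 8, 8)
Flatten and transpose: (B, 64, 128)
Prepend CLS token: (B, 65, 128)
Add position embeddings: (B, 65, 128)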
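The Conv2d trick can be checked numerically. A Conv2d with kernel_size = stride = $P$ visits each non-overlapping $P \times P$ patch exactly once. It should therefore match a second route: cut the image into patches, flatten each into a $3 \cdot 4 \cdot 4 = 48$-dimensional vector, and apply one shared linear layer. The standalone sketch below compares both routes on a random batch and then assembles the 65-token sequence. The batch size of 2 and the zero-valued CLS token and position embeddings are chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 4, 128   # batch size 2 is arbitrary
x = torch.randn(B, C, H, W)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

# Route 1: the Conv2d trick, then (B, D, 8, 8) -> (B, 64, D)
tokens_conv = conv(x).flatten(2).transpose(1, 2)

# Route 2: explicit non-overlapping patches + one shared linear map
patches = x.unfold(2, P, P).unfold(3, P, P)                            # (B, C, 8, 8, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)  # (B, 64, 48)
weight = conv.weight.reshape(D, -1)            # (128, 48), flattened in (C, P, P) order
tokens_linear = patches @ weight.T + conv.bias  # (B, 64, 128)

# Prepend a CLS token and add position embeddings (zeros here, learnable in the model)
cls_token = torch.zeros(1, 1, D)
pos = torch.zeros(1, 1 + tokens_conv.shape[1], D)
seq = torch.cat([cls_token.expand(B, -1, -1), tokens_conv], dim=1) + pos

print(tokens_conv.shape, tokens_linear.shape, seq.shape)
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-5))
```

Because both routes produce the same tokens, the convolution is just an efficient way to write the per-patch linear projection. Patches do not overlap, so no information is shared between them at this stage.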

MultiHeadSelfAttention: From Scratch

The heart of the Transformer. Scaled dot-product attention with multiple heads, built entirely from linear projections and matrix multiplications:

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 32
        self.scale = self.head_dim ** -0.5

        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

The forward pass implements the attention equation directly:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V $$

# MultiHeadSelfAttention.forward (F = torch.nn.functional)
def forward(self, x):
    B, N, C = x.shape
    # Project and reshape for multi-head
    Q = self.W_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0,2,1,3)
    K = self.W_k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0,2,1,3)
    V = self.W_v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0,2,1,3)

    # Scaled dot-product attention
    attn = (Q @ K.transpose(-2,-1)) * self.scale  # (B,h,N,N)
    attn = F.softmax(attn, dim=-1)
    out = attn @ V  # (B, h, N, d_k)

    # Concatenate heads and project
    out = out.permute(0,2,1,3).reshape(B, N, C)
    return self.out_proj(out), attn

We return the attention weights $\text{attn} \in \mathbb{R}^{B \times h \times N \times N}$ alongside the output. They are useful for visualizing which patches each token attends to.
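To confirm that the manual computation is ordinary scaled dot-product attention, it can be compared against PyTorch's fused `F.scaled_dot_product_attention`. That function is available from PyTorch 2.0, and its default scale is also $1/\sqrt{d_k}$. The sketch below re-implements the per-head math as a standalone function (`manual_attention` is an illustrative name) and runs it on random Q, K, V already split into heads. It also checks that every row of the attention matrix sums to 1.

```python
import torch
import torch.nn.functional as F

def manual_attention(Q, K, V):
    # Q, K, V: (B, h, N, d_k), the same layout as in MultiHeadSelfAttention.forward
    scale = Q.shape[-1] ** -0.5
    attn = F.softmax((Q @ K.transpose(-2, -1)) * scale, dim=-1)  # (B, h, N, N)
    return attn @ V, attn                                         # (B, h, N, d_k)

torch.manual_seed(0)
B, h, N, d_k = 2, 4, 65, 32
Q, K, V = (torch.randn(B, h, N, d_k) for _ in range(3))

out_manual, attn = manual_attention(Q, K, V)
out_fused = F.scaled_dot_product_attention(Q, K, V)  # requires PyTorch >= 2.0

print(torch.allclose(out_manual, out_fused, atol=1e-5))
print(torch.allclose(attn.sum(dim=-1), torch.ones(B, h, N)))
```

The fused kernel is faster and more memory-efficient, but it does not return the attention weights. That is one reason to keep the explicit version when attention maps are needed.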

Why the $\sqrt{d_k}$ Scaling?

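Without scaling, the dot products in $QK^\top$ grow in magnitude with $d_k$. If the entries of a query $q$ and a key $k$ are independent with zero mean and unit variance, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has variance $d_k$. Large dot products push the softmax toward a near-one-hot distribution, where its gradients are close to zero. Dividing by $\sqrt{d_k}$ brings the variance back to 1 under the same assumption.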
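The variance argument is easy to check numerically. The sketch below draws random queries and keys with i.i.d. $\mathcal{N}(0, 1)$ entries, which is an assumption about the inputs, not a property of trained weights. It measures the variance of their dot products before and after dividing by $\sqrt{d_k}$, and compares the largest softmax weight with and without scaling. The helper names are illustrative.

```python
import torch

def dot_product_variance(d_k: int, n: int = 10_000, seed: int = 0):
    """Variance of q . k over n random pairs, raw and divided by sqrt(d_k)."""
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(n, d_k, generator=g)   # entries i.i.d. N(0, 1)
    k = torch.randn(n, d_k, generator=g)
    dots = (q * k).sum(dim=-1)             # n independent dot products
    return dots.var().item(), (dots / d_k ** 0.5).var().item()

def softmax_peak(d_k: int, num_keys: int = 16, seed: int = 1):
    """Largest attention weight of one query over num_keys keys, raw vs scaled."""
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(d_k, generator=g)
    keys = torch.randn(num_keys, d_k, generator=g)
    scores = keys @ q                      # (num_keys,)
    raw = torch.softmax(scores, dim=-1).max().item()
    scaled = torch.softmax(scores / d_k ** 0.5, dim=-1).max().item()
    return raw, scaled

for d_k in (8, 32, 128):
    raw_var, scaled_var = dot_product_variance(d_k)
    raw_peak, scaled_peak = softmax_peak(d_k)
    print(f"d_k={d_k:3d}  var raw={raw_var:6.1f} scaled={scaled_var:.2f}  "
          f"max weight raw={raw_peak:.2f} scaled={scaled_peak:.2f}")
```

The raw variance tracks $d_k$ while the scaled variance stays near 1. Dividing the scores by a constant larger than 1 can only flatten the softmax, so the unscaled peak weight is never smaller than the scaled one.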

TransformerBlock: Pre-Norm Architecture

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),   # 128 -> 512
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),   # 512 -> 128
        )

    def forward(self, x):
        # Pre-norm MHSA + residual
        attn_out, attn_w = self.attn(self.norm1(x))
        x = x + attn_out
        # Pre-norm MLP + residual
        x = x + self.mlp(self.norm2(x))
        return x, attn_w

Pre-Norm vs Post-Norm

ViT applies LayerNorm before each sub-layer (pre-norm) rather than after the residual addition (post-norm, as in the original Transformer). This stabilizes training of deeper models because the residual path stays unnormalized: gradients flow through the identity connection without being rescaled by LayerNorm.
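The two orderings differ only in where the LayerNorm sits relative to the residual sum. The sketch below wraps an arbitrary sub-layer `fn` both ways. The wrapper names `PreNormResidual` and `PostNormResidual` and the all-zero test sub-layer are illustrative, not from any library. With a sub-layer that outputs zeros, the pre-norm wrapper passes its input through unchanged with a gradient of exactly 1, while the post-norm wrapper still renormalizes the stream.

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """y = x + fn(LayerNorm(x)): the ordering used by TransformerBlock above."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return x + self.fn(self.norm(x))

class PostNormResidual(nn.Module):
    """y = LayerNorm(x + fn(x)): the ordering of the original Transformer."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.norm(x + self.fn(x))

def zero_linear(dim):
    # Hypothetical sub-layer whose output is always zero, isolating the residual path
    layer = nn.Linear(dim, dim)
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)
    return layer

torch.manual_seed(0)
D = 128
x = 3.0 * torch.randn(2, 65, D) + 1.0   # deliberately not zero-mean / unit-variance
pre = PreNormResidual(D, zero_linear(D))
post = PostNormResidual(D, zero_linear(D))

x_in = x.clone().requires_grad_(True)
pre(x_in).sum().backward()              # gradient through the pre-norm wrapper

print(torch.equal(pre(x), x))                      # identity path left untouched
print(torch.allclose(post(x), x, atol=1e-3))       # LayerNorm rescales the stream
print(torch.allclose(x_in.grad, torch.ones_like(x_in)))
```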

GELU Activation

The MLP uses GELU instead of ReLU:

$$ \text{GELU}(x) = x \cdot \Phi(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left[\sqrt{2/\pi}(x + 0.044715x^3)\right]\right) $$

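Here $\Phi$ is the standard normal CDF, so GELU scales each input by the probability that a standard Gaussian variable falls below it. For large $|x|$ it behaves like ReLU ($\text{GELU}(x) \to x$ as $x \to \infty$ and $\to 0$ as $x \to -\infty$), but it is smooth around zero. ReLU's gradient is exactly zero for every negative input; GELU still passes gradient to most negative inputs. GELU is the MLP activation in BERT, GPT-2, and the original ViT.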
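Both forms of the formula can be checked against PyTorch. `F.gelu` (and `nn.GELU`) computes the exact erf-based form by default and the tanh form with `approximate="tanh"` (PyTorch 1.12+). The sketch below writes both versions by hand, compares them with the built-in, and measures the gradient at $x = -1$, where ReLU's gradient is exactly zero. The helper names are illustrative.

```python
import math
import torch
import torch.nn.functional as F

def gelu_exact(x):
    # x * Phi(x), with the standard normal CDF written via erf
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # The tanh approximation from the formula above
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * x ** 3)))

x = torch.linspace(-5.0, 5.0, 1001)
err_exact = (gelu_exact(x) - F.gelu(x)).abs().max().item()
err_tanh = (gelu_tanh(x) - F.gelu(x, approximate="tanh")).abs().max().item()
approx_gap = (gelu_exact(x) - gelu_tanh(x)).abs().max().item()

# Gradient at x = -1: GELU gives Phi(-1) - phi(-1), ReLU gives exactly 0
z = torch.tensor([-1.0], requires_grad=True)
F.gelu(z).sum().backward()
gelu_grad = z.grad.item()

w = torch.tensor([-1.0], requires_grad=True)
F.relu(w).sum().backward()
relu_grad = w.grad.item()

print(f"exact vs F.gelu: {err_exact:.1e}, tanh vs F.gelu(tanh): {err_tanh:.1e}")
print(f"max gap exact vs tanh: {approx_gap:.1e}")
print(f"GELU'(-1) = {gelu_grad:.4f}, ReLU'(-1) = {relu_grad:.1f}")
```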

The Residual Connection

The same principle as in ResNets applies. For a residual block $y = x + F(x)$, the chain rule gives:

$$ \frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(\frac{\partial F}{\partial x} + \mathbf{I}\right) $$

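The identity term $\mathbf{I}$ gives the gradient a direct path to $x$ that bypasses $F$ entirely. Across a stack of blocks the Jacobians multiply. Expanding $\prod_\ell (\mathbf{I} + \partial F_\ell / \partial x_\ell)$ always leaves a bare $\mathbf{I}$ term, so part of the gradient skips every sub-layer, no matter the depth.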
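The identity term can be seen directly with autograd. The sketch below is a toy setup: each sub-layer is a bias-free linear map $F_\ell(x) = W_\ell x$ with small random weights (std 0.1). This is an assumption chosen so the Jacobians are easy to read; the real block uses MHSA and an MLP. For a single residual block $y = x + F(x)$, `torch.autograd.functional.jacobian` should return exactly $W + \mathbf{I}$. Stacking ten such sub-layers with and without skip connections shows the plain stack's Jacobian collapsing toward zero while the residual stack's does not.

```python
import torch

torch.manual_seed(0)
D, depth = 8, 10
# Illustrative sub-layers F_l(x) = W_l x; with std 0.1 each one shrinks the signal
Ws = [0.1 * torch.randn(D, D) for _ in range(depth)]

def one_block(x):
    return x + Ws[0] @ x            # y = x + F(x)

def residual_stack(x):
    for W in Ws:
        x = x + W @ x               # depth residual blocks
    return x

def plain_stack(x):
    for W in Ws:
        x = W @ x                   # same sub-layers, no skip connections
    return x

x0 = torch.randn(D)
jac = torch.autograd.functional.jacobian
J_one, J_res, J_plain = jac(one_block, x0), jac(residual_stack, x0), jac(plain_stack, x0)

print(torch.allclose(J_one, Ws[0] + torch.eye(D)))   # dy/dx = dF/dx + I
print(f"||J_res||_F   = {torch.linalg.norm(J_res):.3f}")
print(f"||J_plain||_F = {torch.linalg.norm(J_plain):.2e}")
```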

VisionTransformer: Full Assembly

class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, num_classes=10,
                 embed_dim=128, num_heads=4,
                 depth=6, mlp_ratio=4.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size,
                                          in_channels, embed_dim)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)        # (B, 65, 128)
        for block in self.blocks:
            x, _ = block(x)            # 6x Transformer
        x = self.norm(x)
        cls = x[:, 0]                  # CLS token
        return self.head(cls)           # (B, 10)

Architecture Summary

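Our ViT-Tiny for CIFAR-10 (the constructor defaults above):

Input: $32 \times 32$ RGB images (3 channels)
Patch size: $4 \times 4$, giving 64 patches plus the CLS token = 65 tokens
Embedding dimension: 128
Attention heads: 4 (head dimension 32)
Depth: 6 Transformer blocks
MLP hidden dimension: 512 (mlp_ratio = 4.0)
Output classes: 10
Parameters: about 1.2M (1,205,898)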

Parameter Count Breakdown

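Where the 1.2M parameters live (default configuration):

Patch embedding: 14,720 (Conv2d 6,272 + CLS token 128 + position embeddings 8,320)
Transformer blocks: 1,189,632 (6 × 198,272)
   per block: attention projections 66,048, MLP 131,712, LayerNorms 512
Final LayerNorm: 256
Classification head: 1,290
Total: 1,205,898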

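Over 98% of the parameters (1,189,632 of 1,205,898) sit in the Transformer blocks. Within each block, the MLP holds about two-thirds of the parameters (131,712) and the attention projections about one-third (66,048). The MLP's two linear layers ($D \times 4D$ and $4D \times D$) carry roughly $8D^2$ weights, against $4D^2$ for $W_q$, $W_k$, $W_v$, and the output projection.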
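The parameter counts can be reproduced by counting the same layer types the modules use. As an illustrative shortcut, the sketch below instantiates each layer on its own with the default hyperparameters rather than building the full model. The function name `vit_param_breakdown` is hypothetical. On the assembled model, `sum(p.numel() for p in model.parameters())` gives the same total.

```python
import torch.nn as nn

def count(module: nn.Module) -> int:
    # Total number of learnable scalars in a module
    return sum(p.numel() for p in module.parameters())

def vit_param_breakdown(img_size=32, patch_size=4, in_channels=3,
                        num_classes=10, embed_dim=128, depth=6, mlp_ratio=4.0):
    D = embed_dim
    num_patches = (img_size // patch_size) ** 2
    mlp_dim = int(D * mlp_ratio)

    # Same layer types and sizes as PatchEmbedding / TransformerBlock / VisionTransformer
    patch = (count(nn.Conv2d(in_channels, D, kernel_size=patch_size, stride=patch_size))
             + D                          # CLS token
             + (1 + num_patches) * D)     # position embeddings
    attn = 4 * count(nn.Linear(D, D))    # W_q, W_k, W_v, out_proj
    mlp = count(nn.Linear(D, mlp_dim)) + count(nn.Linear(mlp_dim, D))
    norms = 2 * count(nn.LayerNorm(D))   # norm1, norm2
    block = attn + mlp + norms
    final = count(nn.LayerNorm(D)) + count(nn.Linear(D, num_classes))
    total = patch + depth * block + final
    return {"patch_embed": patch, "attn_per_block": attn, "mlp_per_block": mlp,
            "block": block, "blocks": depth * block, "norm_and_head": final,
            "total": total}

if __name__ == "__main__":
    b = vit_param_breakdown()
    for name, n in b.items():
        print(f"{name:>15}: {n:,}")
    print(f"blocks share of total: {b['blocks'] / b['total']:.1%}")
    print(f"MLP share within a block: {b['mlp_per_block'] / b['block']:.1%}")
```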

Next: Training and Analysis

That gives us a complete Vision Transformer in pure PyTorch. In Part 3, we train it on CIFAR-10 alongside a CNN baseline and dig into training dynamics, attention maps, positional embeddings, and the data efficiency gap.