Deconstructing ViTs from Scratch

Part 2: PyTorch Implementation

Introduction

In Part 1, we established the mathematical foundations of Vision Transformers: images as patch sequences, the CLS token as a global aggregator, and learnable positional embeddings for spatial awareness. Now we translate every equation into working PyTorch code.

We build four modules from scratch: PatchEmbedding, MultiHeadSelfAttention, TransformerBlock, and VisionTransformer. Crucially, we implement attention manually, without nn.MultiheadAttention, to expose every matrix multiplication and softmax operation.

PatchEmbedding: The Conv2d Trick

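import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, embed_dim=128):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Conv2d trick: kernel_size=stride=patch_size
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, embed_dim)
        )
        self.position_embeddings = nn.Parameter(
            torch.randn(1, 1 + self.num_patches, embed_dim)
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, D, H/P, W/P)
        x = self.projection(x)
        # Flatten the patch grid into a sequence: (B, N, D)
        x = x.flatten(2).transpose(1, 2)
        # Prepend one copy of the CLS token per image: (B, 1 + N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)
        return x + self.position_embeddings  # add learnable positions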

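For CIFAR-10 ($32 \times 32$, $P=4$, $D=128$), the convolution turns each image into an $8 \times 8$ grid of $N = (32/4)^2 = 64$ patches, each covering $4 \times 4 \times 3 = 48$ pixel values and projected to 128 dimensions. Prepending the CLS token yields a sequence of shape $(B, 65, 128)$, matching the positional embedding table of shape $(1, 65, 128)$.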
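Why does a strided convolution implement patch embedding? With kernel_size equal to stride, each output position sees exactly one non-overlapping patch, so the convolution acts as a single Linear layer shared across all patches. The PyTorch sketch below checks this numerically under that assumption: it cuts an image into $4 \times 4$ patches by hand with `Tensor.unfold`, flattens each into a 48-dimensional vector, and applies an `nn.Linear` whose weights are copied from the `nn.Conv2d` kernel. `patchify` is a hypothetical helper written only for this check, not part of PyTorch.

```python
import torch
import torch.nn as nn


def patchify(img: torch.Tensor, p: int) -> torch.Tensor:
    """Hypothetical helper: (B, C, H, W) -> (B, N, C*p*p), patches in row-major order."""
    B, C, H, W = img.shape
    # Unfold H, then W: (B, C, H/p, W/p, p, p)
    x = img.unfold(2, p, p).unfold(3, p, p)
    # Move the patch-grid axes forward: (B, H/p, W/p, C, p, p)
    x = x.permute(0, 2, 3, 1, 4, 5)
    # Flatten each patch in (C, p, p) order, matching Conv2d's weight layout
    return x.reshape(B, (H // p) * (W // p), C * p * p)


torch.manual_seed(0)
P, C, D = 4, 3, 128
img = torch.randn(2, C, 32, 32)

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
# Conv path: (B, D, 8, 8) -> (B, 64, D)
conv_tokens = conv(img).flatten(2).transpose(1, 2)

# Linear path: reuse the conv kernel as a (D, C*P*P) weight matrix
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))
    linear.bias.copy_(conv.bias)
linear_tokens = linear(patchify(img, P))

print(conv_tokens.shape)  # torch.Size([2, 64, 128])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))
```

The convolution is simply the faster, more compact way to express "flatten every patch, then apply one shared Linear layer".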

MultiHeadSelfAttention: From Scratch

This is the core of the Transformer. We implement scaled dot-product attention with multiple heads, entirely from linear projections and matrix multiplications:

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 32
        self.scale = self.head_dim ** -0.5

        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

The forward pass follows the attention formula precisely:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V $$

    def forward(self, x):
        B, N, C = x.shape
        # Project and reshape for multi-head
        Q = self.W_q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0,2,1,3)
        K = self.W_k(x).reshape(B, N, self.num_heads, self.head_dim).permute(0,2,1,3)
        V = self.W_v(x).reshape(B, N, self.num_heads, self.head_dim).permute(0,2,1,3)

        # Scaled dot-product attention
        attn = (Q @ K.transpose(-2,-1)) * self.scale  # (B,h,N,N)
        attn = F.softmax(attn, dim=-1)
        out = attn @ V  # (B, h, N, d_k)

        # Concatenate heads and project
        out = out.permute(0,2,1,3).reshape(B, N, C)
        return self.out_proj(out), attn

The attention weights $\text{attn} \in \mathbb{R}^{B \times h \times N \times N}$ record, for every head, how strongly each query token attends to each key token; after the softmax, every row sums to 1. We return these weights for visualization.
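As a sanity check on the forward pass, the sketch below (assuming PyTorch 2.0 or later, which provides `torch.nn.functional.scaled_dot_product_attention`) repeats the head split, scaled dot-product, and head merge on random tensors and compares the result with PyTorch's built-in kernel. The learned projections are replaced by random stand-ins so the check runs in isolation; `split_heads` and `merge_heads` are hypothetical helper names. It also confirms that each attention row sums to 1 and that splitting and merging heads is an exact round trip.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, C, h = 2, 65, 128, 4   # batch, tokens (64 patches + CLS), width, heads
d_k = C // h                 # 32


def split_heads(t: torch.Tensor) -> torch.Tensor:
    # (B, N, C) -> (B, h, N, d_k), as in the forward pass above
    return t.reshape(B, N, h, d_k).permute(0, 2, 1, 3)


def merge_heads(t: torch.Tensor) -> torch.Tensor:
    # (B, h, N, d_k) -> (B, N, C): concatenate the heads again
    return t.permute(0, 2, 1, 3).reshape(B, N, C)


# Random stand-ins for the projected tokens W_q x, W_k x, W_v x
q_tok, k_tok, v_tok = torch.randn(3, B, N, C)
Q, K, V = split_heads(q_tok), split_heads(k_tok), split_heads(v_tok)

attn = F.softmax((Q @ K.transpose(-2, -1)) * d_k ** -0.5, dim=-1)  # (B, h, N, N)
manual = attn @ V                                                  # (B, h, N, d_k)

# PyTorch's built-in kernel; its default scale is 1/sqrt(d_k)
fused = F.scaled_dot_product_attention(Q, K, V)

print(attn.shape)  # torch.Size([2, 4, 65, 65])
print(torch.allclose(manual, fused, atol=1e-5))
print(torch.allclose(attn.sum(dim=-1), torch.ones(B, h, N), atol=1e-6))
print(torch.equal(merge_heads(split_heads(q_tok)), q_tok))
```

In the full module, `Q`, `K`, and `V` come from `W_q`, `W_k`, and `W_v`; only the source of the inputs differs here.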

Why the $\sqrt{d_k}$ Scaling?

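Without scaling, the entries of $QK^\top$ grow in magnitude with the head dimension $d_k$. If the components of a query $q$ and a key $k$ are independent with zero mean and unit variance, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has variance $d_k$. For large $d_k$, the softmax therefore saturates into near-one-hot distributions, producing vanishingly small gradients. Dividing by $\sqrt{d_k}$ brings the variance of the scores back to 1.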
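The variance argument is easy to check numerically. This standard-library sketch draws query and key vectors with independent $\mathcal{N}(0, 1)$ components (the same assumption as the argument above) and compares the empirical variance of $q \cdot k$ with and without the $1/\sqrt{d_k}$ factor. The trial count and the dimensions tested are arbitrary choices.

```python
import math
import random
import statistics

random.seed(0)


def dot_variance(d_k: int, scaled: bool, trials: int = 2000) -> float:
    """Empirical variance of q.k (optionally divided by sqrt(d_k)) for i.i.d. N(0, 1) entries."""
    scale = 1 / math.sqrt(d_k) if scaled else 1.0
    dots = []
    for _ in range(trials):
        q = [random.gauss(0, 1) for _ in range(d_k)]
        k = [random.gauss(0, 1) for _ in range(d_k)]
        dots.append(scale * sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pvariance(dots)


results = {}
for d_k in (8, 32, 128):
    var_raw = dot_variance(d_k, scaled=False)
    var_scaled = dot_variance(d_k, scaled=True)
    results[d_k] = (var_raw, var_scaled)
    print(f"d_k={d_k:4d}  var(q.k)={var_raw:7.1f}  var(q.k/sqrt(d_k))={var_scaled:.2f}")
```

The unscaled variance should land near 8, 32, and 128, while the scaled variance stays close to 1 at every width.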

TransformerBlock: Pre-Norm Architecture

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),   # 128 -> 512
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),   # 512 -> 128
        )

    def forward(self, x):
        # Pre-norm MHSA + residual
        attn_out, attn_w = self.attn(self.norm1(x))
        x = x + attn_out
        # Pre-norm MLP + residual
        x = x + self.mlp(self.norm2(x))
        return x, attn_w

Pre-Norm vs Post-Norm

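ViT uses pre-norm: LayerNorm is applied to the input of each sub-layer, $x \leftarrow x + F(\text{LN}(x))$, whereas the original Transformer used post-norm, $x \leftarrow \text{LN}(x + F(x))$. Pre-norm stabilizes training, especially for deeper models, because the residual path stays unnormalized: gradients flow back through the identity connection without passing through any LayerNorm.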
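The difference is easiest to see when a sub-layer contributes nothing. In the PyTorch sketch below, `pre_norm`, `post_norm`, and `zero_sublayer` are hypothetical helper names, and the zero branch is a stand-in for an attention or MLP output that is negligible. The pre-norm update leaves the residual stream untouched, while the post-norm update $\text{LN}(x + F(x))$ re-normalizes it at every layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D = 128
norm = nn.LayerNorm(D)
x = 3.0 * torch.randn(2, 65, D) + 1.0  # residual stream with non-unit scale and offset


def zero_sublayer(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for an attention/MLP branch whose output is ~0
    return torch.zeros_like(t)


def pre_norm(t, f):
    # ViT style: normalize the sub-layer input, add its output to the raw stream
    return t + f(norm(t))


def post_norm(t, f):
    # Original-Transformer style: normalize after the residual addition
    return norm(t + f(t))


print(torch.equal(pre_norm(x, zero_sublayer), x))       # identity path preserved
print(torch.allclose(post_norm(x, zero_sublayer), x))   # stream rewritten by LayerNorm
print(post_norm(x, zero_sublayer).std(dim=-1).mean().item())  # close to 1 per token
```

With the zero branch, pre-norm returns `x` unchanged, while post-norm maps every token to zero mean and unit variance regardless of the scale the residual stream had built up.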

GELU Activation

The MLP uses GELU (Gaussian Error Linear Unit) rather than ReLU:

$$ \text{GELU}(x) = x \cdot \Phi(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left[\sqrt{2/\pi}(x + 0.044715x^3)\right]\right) $$

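GELU is a smooth alternative to ReLU. Rather than zeroing every negative input, it weights each input by $\Phi(x)$, the standard normal CDF, so it has no kink at the origin and, unlike ReLU, its gradient is not identically zero for negative inputs. It has become standard in Transformer architectures.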
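To see how closely the tanh formula tracks the exact definition, this standard-library sketch evaluates $x \cdot \Phi(x)$ via `math.erf`, using $\Phi(x) = \tfrac{1}{2}\left[1 + \operatorname{erf}(x/\sqrt{2})\right]$, alongside the approximation above and ReLU. The grid and sample points are arbitrary choices.

```python
import math


def gelu_exact(x: float) -> float:
    # x * Phi(x), with Phi the standard normal CDF written via erf
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # tanh approximation from the formula above
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def relu(x: float) -> float:
    return max(0.0, x)


grid = [i / 100 for i in range(-600, 601)]  # x in [-6, 6]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
print(f"max |exact - tanh approx| on [-6, 6]: {max_err:.1e}")

for x in (-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0):
    print(f"x={x:+.1f}  relu={relu(x):+.4f}  gelu={gelu_exact(x):+.4f}")
```

On this grid the two GELU forms differ by well under $10^{-3}$, and the table shows GELU returning small negative values for negative inputs where ReLU returns exactly 0.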

The Residual Connection

The residual connection $y = x + F(x)$ follows the same principle as in our ResNet series, and the chain rule gives the same gradient:

$$ \frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(\frac{\partial F}{\partial x} + \mathbf{I}\right) $$

The identity term $\mathbf{I}$ gives every layer a direct gradient path, so the gradient does not vanish merely because $\partial F/\partial x$ is small or the network is deep.
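A scalar toy model makes the role of $\mathbf{I}$ concrete. In the PyTorch sketch below, each of 6 "layers" is the hypothetical sub-layer $F(h) = w h$ with $w = 0.1$ (both values arbitrary), so $\partial F / \partial h = w$. Without skip connections the input gradient is the product $w^6$; with them each factor becomes $1 + w$, and autograd should reproduce both.

```python
import torch

w, depth = 0.1, 6


def input_grad(residual: bool) -> float:
    x = torch.tensor(1.0, requires_grad=True)
    h = x
    for _ in range(depth):
        f = w * h                      # toy sub-layer F(h) = w * h
        h = h + f if residual else f   # y = h + F(h)  vs  y = F(h)
    h.backward()
    return x.grad.item()


plain = input_grad(residual=False)
res = input_grad(residual=True)
print(f"plain:    {plain:.2e}  (w**depth = {w ** depth:.2e})")
print(f"residual: {res:.4f}  ((1 + w)**depth = {(1 + w) ** depth:.4f})")
```

The plain stack's gradient shrinks to about $10^{-6}$, while the residual stack's stays near $1.1^6 \approx 1.77$.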

VisionTransformer: Full Assembly

class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4,
                 in_channels=3, num_classes=10,
                 embed_dim=128, num_heads=4,
                 depth=6, mlp_ratio=4.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)        # (B, 65, 128)
        for block in self.blocks:
            x, _ = block(x)            # 6x Transformer
        x = self.norm(x)
        cls = x[:, 0]                  # CLS token
        return self.head(cls)           # (B, 10)

Architecture Summary

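Our ViT-Tiny for CIFAR-10 uses $4 \times 4$ patches (64 patches plus the CLS token, for 65 tokens), embedding dimension $D = 128$, 4 attention heads of dimension 32, depth 6, an MLP hidden dimension of 512, and a 10-way classification head, for roughly 1.2M parameters in total.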

Parameter Count Breakdown

Where do the 1.2M parameters come from?

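Over 98% of the parameters (about 98.7%) sit in the six Transformer blocks, at 198,272 parameters each. The split within a block is not equal: the MLP's two $128 \times 512$ layers hold 131,712 parameters, roughly twice the 66,048 in the four $128 \times 128$ attention projections, so the MLP accounts for about two-thirds of each block. The patch embedding (convolution, CLS token, and positional embeddings) adds 14,720 parameters, and the final LayerNorm plus classification head another 1,546, for a total of 1,205,898.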
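The breakdown can be reproduced by hand. The pure-Python sketch below (default hyperparameters assumed; `linear` and `layernorm` are hypothetical helper names) tallies weights and biases from the layer shapes. The number of heads does not appear because all heads share the same $128 \times 128$ projection matrices. For a live PyTorch model, `sum(p.numel() for p in model.parameters())` is the usual way to obtain the total.

```python
def linear(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out          # weight + bias


def layernorm(d: int) -> int:
    return 2 * d                         # gamma + beta


img, P, C, D, depth, mlp_ratio, classes = 32, 4, 3, 128, 6, 4.0, 10
N = (img // P) ** 2                      # 64 patches
mlp_dim = int(D * mlp_ratio)             # 512

patch_embed = (C * P * P * D + D) + D + (1 + N) * D   # conv + CLS + positions
attn = 4 * linear(D, D)                  # W_q, W_k, W_v, out_proj
mlp = linear(D, mlp_dim) + linear(mlp_dim, D)
block = layernorm(D) + attn + layernorm(D) + mlp
head = layernorm(D) + linear(D, classes)  # final norm + classifier

total = patch_embed + depth * block + head
print(f"patch embedding:   {patch_embed:,}")
print(f"per block:         {block:,} (attention {attn:,}, MLP {mlp:,})")
print(f"all blocks:        {depth * block:,} ({depth * block / total:.1%} of total)")
print(f"final norm + head: {head:,}")
print(f"total:             {total:,}")
```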

Next Steps: Training and Analysis

We have a complete, working Vision Transformer built entirely from scratch in PyTorch. In Part 3, we train it on CIFAR-10 alongside a CNN baseline and analyze training dynamics, attention maps, positional embeddings, and the data efficiency question.