
Deconstructing DiT

Part 2: 250 Lines of PyTorch

Overview

Part 1 explained why DiT replaced the U-Net: not because attention is inherently better than convolution, but because attention's lack of architectural inductive bias becomes an asset once you have enough training data. Now we implement DiT from scratch — patch embedding, the DiTBlock with adaLN-Zero conditioning, the full model, the DDPM training procedure, and ancestral sampling.

The core implementation is roughly $250$ lines across two files, plus a training script. The interesting parts are the conditioning mechanism (adaLN-Zero, which brings the adaptive normalisation of GANs and U-Net diffusion models into the Transformer block), the zero-initialisation trick that makes early training stable, and the surprising amount of standard ViT machinery that the model inherits unchanged.

File Map

Three files: diffusion.py contains the DDPM noise scheduler and the sinusoidal time-embedding function (~80 lines). dit.py contains PatchEmbed, Attention, DiTBlock, and the full DiT model (~170 lines). train.py handles dataset generation, the training loop, and sampling.

The DDPM machinery is separated from the model because it is reusable across architectures: the same noise scheduler works with a U-Net, a DiT, or any other noise predictor. The model file is self-contained. PatchEmbed and Attention are private to it, and it shares no modules with the other series.

The Noise Scheduler

import math    # used by timestep_embedding below
import torch

class DiffusionScheduler:
    def __init__(self, T=200, device="cpu"):
        betas = torch.linspace(1e-4, 0.02, T)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        alpha_bars_prev = torch.cat([torch.ones(1), alpha_bars[:-1]])

        self.T = T
        self.betas = betas.to(device)
        self.alphas = alphas.to(device)
        self.alpha_bars = alpha_bars.to(device)
        self.alpha_bars_prev = alpha_bars_prev.to(device)
        self.sqrt_alpha_bars = torch.sqrt(alpha_bars).to(device)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - alpha_bars).to(device)
        self.posterior_var = (betas * (1 - alpha_bars_prev) /
                              (1 - alpha_bars)).to(device)

The constructor precomputes every quantity that the training and sampling loops will need. Storing them as precomputed tensors, rather than recomputing them on every step, is standard practice. The seven length-$T$ tensors hold $7T = 1{,}400$ float values for $T = 200$. That is negligible memory, and the loops only ever index into them.

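The linear $\beta$ schedule. $\beta_t$ controls how much noise is added at step $t$. Ho et al.'s DDPM paper interpolated $\beta$ linearly from $10^{-4}$ to $0.02$ over $T = 1000$ steps. Our $T = 200$ keeps the same endpoints. Each individual step therefore adds more noise, but the chain as a whole adds much less. The final $\bar{\alpha}$ (the cumulative product defined next) is about $0.13$ instead of about $4 \times 10^{-5}$. As a result, the last noisy image still keeps a signal coefficient of $\sqrt{\bar{\alpha}} \approx 0.36$ rather than being almost pure noise, even though sampling starts from pure $\mathcal{N}(0, I)$. For toy experiments this mismatch is tolerable. For production image generation you would use $T = 1000$ and a cosine schedule (which adds noise more uniformly over $t$).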

The cumulative product $\bar{\alpha}_t = \prod_{s \leq t} \alpha_s$. This is the key derived quantity for DDPM. It tells you: starting from $x_0$, after $t$ diffusion steps, the resulting $x_t$ has the closed form $x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \varepsilon$, where $\varepsilon \sim \mathcal{N}(0, I)$. The cumulative product means we never have to iteratively apply the noise step — we can jump straight to any $t$.
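The $T = 200$ trade-off above is easy to quantify. The sketch below recomputes the final $\bar{\alpha}$ for both chain lengths from the stated endpoints, in float64. It is a standalone check: `alpha_bar_final` is a hypothetical helper, not part of the repo.

```python
import torch

def alpha_bar_final(T, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: last entry of cumprod(1 - beta) for a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float64)
    return torch.cumprod(1.0 - betas, dim=0)[-1].item()

ab_200 = alpha_bar_final(200)
ab_1000 = alpha_bar_final(1000)
print(f"T=200:  alpha_bar={ab_200:.3f}  sqrt={ab_200 ** 0.5:.3f}")    # about 0.13 and 0.36
print(f"T=1000: alpha_bar={ab_1000:.1e}  sqrt={ab_1000 ** 0.5:.4f}")  # about 4e-5 and 0.006
```

The square root of each value is the weight that $x_0$ still carries in the last noisy image.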

The single-shot forward process.

def q_sample(self, x0, t, eps):
    sqrt_ab   = self.sqrt_alpha_bars[t].reshape(-1, 1, 1, 1)
    sqrt_omab = self.sqrt_one_minus_alpha_bars[t].reshape(-1, 1, 1, 1)
    return sqrt_ab * x0 + sqrt_omab * eps

Two coefficients are looked up by indexing the precomputed tensors with the batch of timesteps $t$. The .reshape(-1, 1, 1, 1) turns the resulting $(B,)$ vectors of per-sample coefficients into broadcast-compatible $(B, 1, 1, 1)$ tensors. These multiply with images of shape $(B, C, H, W)$. This is the only PyTorch idiom that takes a moment to recognise. Once you do, the three-line body maps directly onto the closed-form equation above.
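Here is a standalone sketch of the same broadcasting, assuming the linear $T = 200$ schedule above and unit-variance stand-in data. Module-level tensors replace the scheduler's attributes. Because $(\sqrt{\bar\alpha_t})^2 + (\sqrt{1-\bar\alpha_t})^2 = 1$, the noisy images keep unit variance whatever mix of timesteps the batch contains. This is the "variance-preserving" property of DDPM's forward process.

```python
import torch

# Standalone copies of the scheduler's precomputed coefficients (T=200, linear betas)
T = 200
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)
sqrt_ab_all = alpha_bars.sqrt()
sqrt_omab_all = (1.0 - alpha_bars).sqrt()

def q_sample(x0, t, eps):
    # (B,) coefficient vectors -> (B, 1, 1, 1) so they broadcast over (C, H, W)
    sqrt_ab = sqrt_ab_all[t].reshape(-1, 1, 1, 1)
    sqrt_omab = sqrt_omab_all[t].reshape(-1, 1, 1, 1)
    return sqrt_ab * x0 + sqrt_omab * eps

torch.manual_seed(0)
B = 4096
x0 = torch.randn(B, 3, 4, 4)          # unit-variance stand-in for data
eps = torch.randn_like(x0)
t = torch.randint(0, T, (B,))         # a different timestep for each sample
x_t = q_sample(x0, t, eps)

print(x_t.shape)                      # torch.Size([4096, 3, 4, 4])
print(x_t.var().item())               # close to 1.0
```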

Sinusoidal Time Embedding

def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, device=t.device) / half)
    args = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
    embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2 == 1:
        embed = torch.cat([embed, torch.zeros_like(embed[:, :1])], dim=-1)
    return embed

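This is the sinusoidal embedding from the original Transformer paper, applied to diffusion timesteps instead of token positions. The frequencies are identical, but the layout differs slightly. This version concatenates all the cosines and then all the sines, the convention later diffusion code such as DiT uses, instead of interleaving sin and cos channel by channel. That permutation of channels is irrelevant once the embedding passes through a learned linear layer. The intuition: we want each timestep $t \in [0, T-1]$ to get a unique, smooth representation. The sinusoidal embedding satisfies both properties. Different $t$ values produce different embeddings, but adjacent $t$ values produce similar embeddings, which is what gradient descent needs to learn smooth time-dependent functions.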

The max_period argument controls the lowest frequency. Take max_period $= 10{,}000$ and $\dim/2 = 32$ frequencies (for hidden $= 64$). The frequencies then run from $1$ down to $10{,}000^{-31/32}$ radians per step, so the periods span from $2\pi \approx 6$ steps to roughly $4.7 \times 10^4$ steps. The high-frequency channels distinguish neighbouring timesteps. The low-frequency channels vary smoothly across all $T = 200$ timesteps.
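A sketch checking those numbers for hidden $= 64$, with the function copied from above. The periods run from $2\pi$ to about $4.7 \times 10^4$ steps. Neighbouring timesteps get nearly parallel embeddings, while distant ones do not.

```python
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    # Same function as above
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, device=t.device) / half)
    args = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
    embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2 == 1:
        embed = torch.cat([embed, torch.zeros_like(embed[:, :1])], dim=-1)
    return embed

dim = 64
half = dim // 2
freqs = torch.exp(-math.log(10000) * torch.arange(0, half) / half)   # radians per step
periods = 2 * math.pi / freqs                                          # steps per full cycle
print(periods[0].item(), periods[-1].item())   # about 6.28 and about 4.7e4

emb = timestep_embedding(torch.tensor([10, 11, 150]), dim)
cos_sim = torch.nn.functional.cosine_similarity
near = cos_sim(emb[0], emb[1], dim=0).item()   # t=10 vs t=11
far = cos_sim(emb[0], emb[2], dim=0).item()    # t=10 vs t=150
print(emb.shape, near, far)                    # adjacent steps are far more similar
```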

The PatchEmbed

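import math
import torch
import torch.nn as nn
import torch.nn.functional as F   # math and F are used by Attention below

class PatchEmbed(nn.Module):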
    def __init__(self, img_size, patch_size, in_channels, hidden):
        super().__init__()
        assert img_size % patch_size == 0
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, hidden, kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                       # (B, hidden, H/p, W/p)
        return x.flatten(2).transpose(1, 2)    # (B, N_patches, hidden)

A stride-$p$ convolution does the patch embedding in one operation. Each $p \times p$ patch of input pixels gets convolved into one $\text{hidden}$-dimensional token. For a $16 \times 16$ image with patch size $4$, the output is $4 \times 4 = 16$ tokens of dimension $\text{hidden}$.

This conv-as-patch-embed trick is the standard ViT pattern. Alternative implementations use nn.Unfold + nn.Linear — mathematically equivalent, slightly less efficient because the conv kernel fuses the two operations. The conv version is also more familiar to anyone who has read existing ViT code.
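A sketch checking that equivalence numerically, assuming the article's sizes (16×16 RGB, $p = 4$, hidden $= 64$). Copying the conv kernel into a linear layer reproduces the same tokens. This works because nn.Unfold flattens each patch in the same (channel, row, column) order as the conv weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, img, p, hidden = 2, 3, 16, 4, 64

conv = nn.Conv2d(C, hidden, kernel_size=p, stride=p)
unfold = nn.Unfold(kernel_size=p, stride=p)
linear = nn.Linear(C * p * p, hidden)
with torch.no_grad():
    # Conv weight (hidden, C, p, p) flattened to (hidden, C*p*p) matches Unfold's patch layout
    linear.weight.copy_(conv.weight.reshape(hidden, C * p * p))
    linear.bias.copy_(conv.bias)

x = torch.randn(B, C, img, img)
tokens_conv = conv(x).flatten(2).transpose(1, 2)        # (B, N, hidden)
tokens_unfold = linear(unfold(x).transpose(1, 2))       # (B, N, hidden)
print(tokens_conv.shape)                                # torch.Size([2, 16, 64])
print(torch.allclose(tokens_conv, tokens_unfold, atol=1e-5))   # True
```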

After the conv, the output is a $4$D tensor $(B, \text{hidden}, H/p, W/p)$. The flatten(2).transpose(1, 2) converts it to the Transformer-friendly $(B, N, C)$ shape that the rest of the model expects.

The Attention Module

class Attention(nn.Module):
    def __init__(self, hidden, n_heads):
        super().__init__()
        assert hidden % n_heads == 0
        self.n_heads = n_heads
        self.d_head = hidden // n_heads
        self.qkv = nn.Linear(hidden, 3 * hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        B, N, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        q = q.reshape(B, N, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, N, self.n_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, N, self.n_heads, self.d_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().reshape(B, N, C)
        return self.proj(y)

Standard multi-head self-attention. No causal mask — this is the key difference from the TinyGPT attention in an earlier series. DiT is processing an image as a set of patches; we want each patch to attend to all others, including "future" ones in the patch ordering. The lack of a mask is what makes this bidirectional attention (the same as a ViT, not the same as a GPT).

Fusing the Q, K and V projections into a single linear layer and then splitting the output is the standard efficiency trick. The forward pass reshapes to expose the head dimension and transposes to put heads on the second axis, so the matmul $q\,k^\top$ runs over the right axes. It then computes attention, transposes back and projects. That is nine lines of forward, unchanged from any other ViT.
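On PyTorch 2.0 or later, the same computation is available as the fused F.scaled_dot_product_attention. The sketch below compares the two. `fused_forward` is a hypothetical helper that reuses the module's weights. Passing is_causal=True turns it into GPT-style masked attention, which visibly changes the output.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    # Same module as above
    def __init__(self, hidden, n_heads):
        super().__init__()
        assert hidden % n_heads == 0
        self.n_heads = n_heads
        self.d_head = hidden // n_heads
        self.qkv = nn.Linear(hidden, 3 * hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x):
        B, N, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        q = q.reshape(B, N, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, N, self.n_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, N, self.n_heads, self.d_head).transpose(1, 2)
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        att = F.softmax(att, dim=-1)
        y = (att @ v).transpose(1, 2).contiguous().reshape(B, N, C)
        return self.proj(y)

def fused_forward(m, x, is_causal=False):
    """Hypothetical helper: m's projections, with attention done by the fused kernel."""
    B, N, C = x.shape
    q, k, v = m.qkv(x).split(C, dim=-1)
    q, k, v = (z.reshape(B, N, m.n_heads, m.d_head).transpose(1, 2) for z in (q, k, v))
    # Default scale is 1/sqrt(d_head), matching the manual version
    y = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return m.proj(y.transpose(1, 2).reshape(B, N, C))

torch.manual_seed(0)
attn = Attention(hidden=64, n_heads=4)
x = torch.randn(2, 16, 64)
with torch.no_grad():
    ref = attn(x)
    fused = fused_forward(attn, x)
    causal = fused_forward(attn, x, is_causal=True)

print(torch.allclose(ref, fused, atol=1e-5))    # True: same bidirectional attention
print(torch.allclose(ref, causal, atol=1e-5))   # False: a causal mask changes the result
```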

The DiTBlock with adaLN-Zero

This is where DiT diverges from standard ViT. The block applies attention and MLP with residual connections, but the LayerNorms are data-dependent — scale and shift come from the conditioning vector $c$ rather than from learned per-layer parameters.

class DiTBlock(nn.Module):
    def __init__(self, hidden, n_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden, elementwise_affine=False)
        self.attn = Attention(hidden, n_heads)
        self.norm2 = nn.LayerNorm(hidden, elementwise_affine=False)
        h = int(hidden * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden, h), nn.GELU(),
            nn.Linear(h, hidden),
        )
        self.adaLN_modulation = nn.Linear(hidden, 6 * hidden)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, c):
        params = self.adaLN_modulation(c).unsqueeze(1)        # (B, 1, 6H)
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = params.chunk(6, dim=-1)
        x = x + alpha1 * self.attn(self.norm1(x) * (1 + gamma1) + beta1)
        x = x + alpha2 * self.mlp (self.norm2(x) * (1 + gamma2) + beta2)
        return x

Four critical design choices are visible here.

(1) elementwise_affine=False on the LayerNorms. The standard nn.LayerNorm has learnable $\gamma$ and $\beta$ parameters per element. We turn those off — the LayerNorm just normalizes, the affine comes from elsewhere. This is essential because adaLN's whole point is to replace the per-layer learned affine with a data-dependent one.

(2) Six modulation parameters per block. The line self.adaLN_modulation = nn.Linear(hidden, 6 * hidden) projects the $\text{hidden}$-dimensional conditioning vector to $6 \times \text{hidden}$ parameters: scale ($\gamma$) and shift ($\beta$) for each LayerNorm, plus a residual-gate ($\alpha$) for each sub-layer. The .chunk(6, dim=-1) splits this into the six pieces.

(3) Zero initialisation of adaLN_modulation. The two nn.init.zeros_ lines are the "Zero" in adaLN-Zero. At initialisation they give $\gamma = 0$, $\beta = 0$, $\alpha = 0$, so each sub-layer's residual contribution is exactly zero. The block is therefore the identity at step zero: $x \to x$. The model gradually learns to move the modulation layer away from zero as training proceeds. In the DiT paper's ablations, adaLN-Zero gave the lowest FID of the conditioning variants tested. It beat plain adaLN without the zero-initialisation, as well as in-context and cross-attention conditioning.

(4) The modulation pattern. Reading the forward pass: $\text{LayerNorm}(x) \cdot (1 + \gamma) + \beta$ is the data-dependent affine — multiply by $1 + \gamma$ rather than $\gamma$ because we want the default $\gamma = 0$ behaviour to be "scale by 1" (identity). Then $\alpha \cdot \text{sub-layer}(\cdot)$ is the residual gate — when $\alpha = 0$, the sub-layer contributes nothing.
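The sketch below checks the identity-at-initialisation claim, plus one consequence it implies. Every branch is multiplied by $\alpha = 0$, so at step zero only the $\alpha$ rows of adaLN_modulation receive gradient. The gates therefore open before the scales and shifts start to move. `DiTBlockSketch` copies DiTBlock above but swaps in torch's nn.MultiheadAttention for the article's Attention class to stay short. The loss is arbitrary.

```python
import torch
import torch.nn as nn

class DiTBlockSketch(nn.Module):
    """DiTBlock from above, with nn.MultiheadAttention standing in for Attention."""
    def __init__(self, hidden, n_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(hidden, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden, elementwise_affine=False)
        h = int(hidden * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(hidden, h), nn.GELU(), nn.Linear(h, hidden))
        self.adaLN_modulation = nn.Linear(hidden, 6 * hidden)
        nn.init.zeros_(self.adaLN_modulation.weight)
        nn.init.zeros_(self.adaLN_modulation.bias)

    def forward(self, x, c):
        params = self.adaLN_modulation(c).unsqueeze(1)                # (B, 1, 6H)
        gamma1, beta1, alpha1, gamma2, beta2, alpha2 = params.chunk(6, dim=-1)
        a_in = self.norm1(x) * (1 + gamma1) + beta1
        x = x + alpha1 * self.attn(a_in, a_in, a_in, need_weights=False)[0]
        x = x + alpha2 * self.mlp(self.norm2(x) * (1 + gamma2) + beta2)
        return x

torch.manual_seed(0)
H = 64
block = DiTBlockSketch(H, n_heads=4)
x = torch.randn(2, 16, H)
c = torch.randn(2, H)

out = block(x, c)
print(torch.equal(out, x))     # True: both residual branches are gated by alpha = 0

out.pow(2).sum().backward()    # any loss that depends on the output
# Weight rows grouped as gamma1, beta1, alpha1, gamma2, beta2, alpha2
g = block.adaLN_modulation.weight.grad.reshape(6, H, H).abs().amax(dim=(1, 2))
print(g)                       # only entries 2 and 5 (alpha1, alpha2) are non-zero
```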

Why adaLN-Zero Works So Well

adaLN-Zero is one of those design choices that looks arbitrary at first glance and turns out to be load-bearing. The benefits compared to cross-attention conditioning (the alternative used in earlier diffusion models):

Simplicity: adaLN-Zero only adds a small linear layer per block. Cross-attention adds full attention modules.

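Same expressivity for single-vector conditioning: When the conditioning is a single vector (time + class), there is no sequence to attend over. Cross-attention to a single key puts softmax weight $1$ on it, so every image token receives the same projected copy of $c$. That is a per-sample shift, which adaLN's $\beta$ already provides alongside the scale and the gate.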

Zero initialisation stability: Cross-attention has random initial weights and contributes random output from step zero. adaLN-Zero contributes nothing from step zero. The latter is much easier to train.

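For text-conditioning at scale: Variable-length text needs attention over tokens, and large models handle that side in different ways. PixArt-α adds cross-attention layers for text. SD3's MMDiT instead runs joint attention over concatenated text and image tokens. Both keep adaLN-style modulation for the timestep, and SD3 also feeds a pooled text embedding through it. The pattern is: adaptive-norm modulation for everything you can express as a single vector, attention for sequences.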
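The single-vector argument can be checked directly. The sketch below uses PyTorch's nn.MultiheadAttention as a stand-in cross-attention layer; it is not code from the repo. With only one key, softmax assigns that key weight 1. Every image token therefore receives the identical vector out_proj(v_proj(c)), which amounts to a per-sample additive shift.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, H = 2, 16, 64
xattn = nn.MultiheadAttention(H, num_heads=4, batch_first=True)

tokens = torch.randn(B, N, H)     # image tokens (queries)
c = torch.randn(B, 1, H)          # one conditioning vector = a length-1 "sequence"

with torch.no_grad():
    out, _ = xattn(tokens, c, c, need_weights=False)
    # Softmax over a single key is exactly 1, so each query just gets the value vector
    W_v = xattn.in_proj_weight[2 * H:]
    b_v = xattn.in_proj_bias[2 * H:]
    expected = xattn.out_proj(c @ W_v.T + b_v)    # (B, 1, H), independent of the tokens

print(torch.allclose(out, expected.expand(B, N, H), atol=1e-5))   # True
```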

The Full DiT Model

class DiT(nn.Module):
    def __init__(self, img_size=16, patch_size=4, in_channels=3,
                 hidden=64, n_layers=4, n_heads=4, num_classes=16,
                 use_class_cond=True):
        super().__init__()
        self.cfg = ...  # store config
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, hidden)
        n_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, hidden))
        nn.init.normal_(self.pos_embed, std=0.02)

        self.t_embed = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                                     nn.Linear(hidden, hidden))
        if use_class_cond:
            self.y_embed = nn.Embedding(num_classes, hidden)
        self.blocks = nn.ModuleList([DiTBlock(hidden, n_heads) for _ in range(n_layers)])

        self.final_norm   = nn.LayerNorm(hidden, elementwise_affine=False)
        self.final_mod    = nn.Linear(hidden, 2 * hidden)
        self.final_linear = nn.Linear(hidden, patch_size * patch_size * in_channels)
        nn.init.zeros_(self.final_mod.weight)
        nn.init.zeros_(self.final_mod.bias)
        nn.init.zeros_(self.final_linear.weight)
        nn.init.zeros_(self.final_linear.bias)

The model is a stack of DiT blocks, bookended by patch-embed on the way in and project-and-unpatchify on the way out. The final projection is also zero-initialised, so the very first forward pass predicts zero noise. That corresponds to the clean-image estimate $\hat{x}_0 = x_t / \sqrt{\bar\alpha_t}$. It is a safe starting point whose initial loss is simply $\mathbb{E}[\varepsilon^2] \approx 1$.
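A sketch of what the zero-initialised head does on the first step. Random features stand in for the DiT blocks' output, since the full forward pass is not shown here. The prediction is exactly zero and the loss starts at $\mathbb{E}[\varepsilon^2] \approx 1$. The gradient updates the head, while nothing yet flows back into the blocks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden, p, C, n_tokens = 64, 4, 3, 16

# Zero-initialised output head, as in DiT.__init__ above
final_linear = nn.Linear(hidden, p * p * C)
nn.init.zeros_(final_linear.weight)
nn.init.zeros_(final_linear.bias)

# Stand-in for whatever the DiT blocks produce
features = torch.randn(8, n_tokens, hidden, requires_grad=True)
eps_pred = final_linear(features)                    # (8, 16, 48), all zeros
eps = torch.randn_like(eps_pred)
loss = F.mse_loss(eps_pred, eps)
loss.backward()

print(eps_pred.abs().max().item())                   # 0.0
print(loss.item())                                   # close to 1.0, i.e. E[eps^2]
print(final_linear.weight.grad.abs().max().item() > 0)   # True: the head gets gradient
print(features.grad.abs().max().item())              # 0.0: nothing reaches the blocks yet
```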

The conditioning pipeline. The model conditions on (timestep, class):

def _condition(self, t, y):
    t_e = timestep_embedding(t, self.cfg.hidden)        # (B, H) sinusoidal
    t_e = self.t_embed(t_e)                              # MLP refinement
    if self.use_class_cond and y is not None:
        y_e = self.y_embed(y)
        return t_e + y_e
    return t_e

Sinusoidal time embedding, MLP refinement, class embedding added if available. The result is one conditioning vector per sample, combining its timestep and its class, that every DiT block sees. This conditioning is what flows into the adaLN-Zero modulation layer in each block.

Un-Patchifying the Output

def unpatchify(self, x):
    B = x.size(0)
    p = self.patch_size
    H = W = self.img_size // p
    x = x.reshape(B, H, W, p, p, self.in_channels)
    x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
    return x.reshape(B, self.in_channels, H * p, W * p)

The model outputs $(B, N, p^2 \cdot C)$: one feature vector per patch, where the feature is the flattened pixel patch. We need to fold this back to $(B, C, H, W)$ for the loss computation.

The reshape into $(B, H, W, p, p, C)$ recovers the patch grid plus per-patch pixel layout. The permute(0, 5, 1, 3, 2, 4) interleaves the patch index and within-patch index correctly. The final reshape merges them into pixel coordinates.

This is the most error-prone part of any ViT-style model. The standard mistake is getting the permute axis order wrong, which produces an image with the right pixel values but the wrong spatial layout — patches arranged correctly relative to each other but flipped or transposed internally. The fix is testing: ensure that patchify(unpatchify(x)) == x for some test pattern. If you write a ViT from scratch, write that test first.
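Following that advice, here is a sketch of the round-trip test. `patchify` is a hypothetical helper written as the exact inverse of the unpatchify above, with row-major patches each flattened in (row, column, channel) order. The arange test pattern gives every pixel a distinct value, so any layout error shows up. The last lines show that a plausible wrong permute keeps every pixel value but fails the test.

```python
import torch

def patchify(x, p):
    # Hypothetical inverse of unpatchify: (B, C, H*p, W*p) -> (B, N, p*p*C)
    B, C, Hp, Wp = x.shape
    H, W = Hp // p, Wp // p
    x = x.reshape(B, C, H, p, W, p).permute(0, 2, 4, 3, 5, 1)   # (B, H, W, p, p, C)
    return x.reshape(B, H * W, p * p * C)

def unpatchify(x, p, C, img_size, perm=(0, 5, 1, 3, 2, 4)):
    # Standalone version of DiT.unpatchify; `perm` is exposed only to try a wrong order
    B = x.size(0)
    H = W = img_size // p
    x = x.reshape(B, H, W, p, p, C).permute(*perm).contiguous()
    return x.reshape(B, C, H * p, W * p)

# Test pattern in which every pixel value is unique
img = torch.arange(2 * 3 * 16 * 16, dtype=torch.float32).reshape(2, 3, 16, 16)
tokens = patchify(img, p=4)
print(tokens.shape)                                      # torch.Size([2, 16, 48])
print(torch.equal(unpatchify(tokens, 4, 3, 16), img))    # True

# Plausible bug: swap the within-patch row axis with the patch-column axis
bad = unpatchify(tokens, 4, 3, 16, perm=(0, 5, 1, 2, 3, 4))
print(torch.equal(bad, img))                             # False
```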

The DDPM Training Loop

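for epoch in range(EPOCHS):
    for x0, y in batches:                         # x0: (B, C, H, W) images, y: (B,) class labels
        # one random timestep per image, on the same device as the data
        t   = torch.randint(0, scheduler.T, (x0.size(0),), device=x0.device)
        eps = torch.randn_like(x0)
        x_t = scheduler.q_sample(x0, t, eps)      # closed-form forward process
        eps_pred = model(x_t, t, y)
        loss = F.mse_loss(eps_pred, eps)          # epsilon-prediction objective
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()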

Standard DDPM. Sample a random timestep, sample fresh Gaussian noise, apply the closed-form forward process to get the noisy image, ask the model to predict the noise, compute MSE between predicted and actual noise. That is the entire training objective.

Why predicting $\varepsilon$ instead of $x_0$. The DDPM loss can be written equivalently as predicting the noise $\varepsilon$ or predicting the clean image $x_0$. The two are linearly related given $x_t$ and $t$. Empirically, $\varepsilon$-prediction gives more stable training because the target has a fixed distribution ($\mathcal{N}(0, I)$) regardless of $t$, while $x_0$-prediction has different per-timestep target distributions. We follow Ho et al.'s original convention.
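The linear relation is just the forward-process equation solved for one unknown: $\hat{x}_0 = (x_t - \sqrt{1-\bar\alpha_t}\,\hat\varepsilon)/\sqrt{\bar\alpha_t}$ and $\hat\varepsilon = (x_t - \sqrt{\bar\alpha_t}\,\hat{x}_0)/\sqrt{1-\bar\alpha_t}$. A small sketch, assuming the same linear $T = 200$ schedule; the helper names are hypothetical.

```python
import torch

torch.manual_seed(0)
alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, 200), dim=0)

x0 = torch.randn(4, 3, 16, 16)
eps = torch.randn_like(x0)
t = torch.tensor([10, 50, 120, 199])
ab = alpha_bars[t].reshape(-1, 1, 1, 1)
x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps        # closed-form forward process

def eps_to_x0(x_t, eps_hat, ab):
    """Clean-image estimate implied by a noise prediction."""
    return (x_t - (1 - ab).sqrt() * eps_hat) / ab.sqrt()

def x0_to_eps(x_t, x0_hat, ab):
    """Noise estimate implied by a clean-image prediction."""
    return (x_t - ab.sqrt() * x0_hat) / (1 - ab).sqrt()

print(torch.allclose(eps_to_x0(x_t, eps, ab), x0, atol=1e-4))   # True
print(torch.allclose(x0_to_eps(x_t, x0, ab), eps, atol=1e-4))   # True
```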

Gradient clipping at $1.0$. Diffusion training occasionally produces gradient spikes from an unlucky batch. Clipping the global gradient norm stops a single bad batch from destabilising the whole run. Clipping is a common safeguard in diffusion training, and $1.0$ is a reasonable default.
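When the global L2 norm of all gradients exceeds the limit, clip_grad_norm_ rescales them by one common factor, so the update keeps its direction. It returns the norm measured before clipping, which is useful to log. A toy sketch with a single hypothetical parameter:

```python
import torch

w = torch.nn.Parameter(torch.zeros(2))
w.grad = torch.tensor([3.0, 4.0])                        # global grad norm 5
total = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)
print(total.item())        # 5.0: the norm before clipping
print(w.grad)              # tensor([0.6000, 0.8000]): same direction, norm 1
```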

Sampling (Ancestral)

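with torch.no_grad():
    x = torch.randn(image_shape, device=device)                  # x_T ~ N(0, I)
    for t in reversed(range(scheduler.T)):
        t_batch = torch.full((x.size(0),), t, device=device, dtype=torch.long)
        eps_pred = model(x, t_batch, y)                          # model expects a (B,) tensor of timesteps
        beta_t = scheduler.betas[t]
        # posterior mean of x_{t-1} given x_t and the predicted noise
        mean = (x - beta_t / scheduler.sqrt_one_minus_alpha_bars[t] * eps_pred) \
            / torch.sqrt(scheduler.alphas[t])
        if t > 0:
            x = mean + torch.sqrt(scheduler.posterior_var[t]) * torch.randn_like(x)
        else:
            x = mean                                             # no noise on the final step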

Start from pure Gaussian noise. Walk back from $t = T-1$ to $t = 0$. At each step, predict the total noise $\varepsilon$ contained in $x_t$, not just the noise added by the last forward step. Compute the posterior mean of $x_{t-1}$ given $x_t$ and the predicted $\varepsilon$. Then add a small amount of fresh Gaussian noise, except at the final step ($t = 0$), where we just take the mean.

The formulas come from the DDPM posterior derivation. After $T$ such steps, the model has denoised pure noise into an image from the training distribution. For our $T = 200$, that is $200$ forward passes per sample — much slower than the single forward pass of a GAN, but with much better stability and mode coverage.
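The sampler's $\varepsilon$-form mean is algebraically the same as the textbook posterior mean $\tilde\mu_t = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}x_0 + \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t$. The equivalence holds once $x_0$ is replaced by the estimate implied by $\varepsilon$, using $\bar\alpha_t = \alpha_t\bar\alpha_{t-1}$. Here is a numeric sketch of that equivalence, with a random tensor standing in for the model's prediction.

```python
import torch

# Linear schedule in float64 so the comparison is not limited by float32 rounding
T = 200
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)
alpha_bars_prev = torch.cat([torch.ones(1, dtype=torch.float64), alpha_bars[:-1]])

torch.manual_seed(0)
x_t = torch.randn(3, 16, 16, dtype=torch.float64)
eps_hat = torch.randn_like(x_t)          # stands in for the model's noise prediction

matches = []
for t in [0, 57, 199]:
    b, a = betas[t], alphas[t]
    ab, ab_prev = alpha_bars[t], alpha_bars_prev[t]
    # epsilon form, as in the sampling loop
    mean_eps = (x_t - b / torch.sqrt(1 - ab) * eps_hat) / torch.sqrt(a)
    # two-coefficient form of q(x_{t-1} | x_t, x_0), with x_0 estimated from eps_hat
    x0_hat = (x_t - torch.sqrt(1 - ab) * eps_hat) / torch.sqrt(ab)
    mean_x0 = (torch.sqrt(ab_prev) * b / (1 - ab)) * x0_hat \
        + (torch.sqrt(a) * (1 - ab_prev) / (1 - ab)) * x_t
    matches.append(torch.allclose(mean_eps, mean_x0))

print(matches)   # [True, True, True]
```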

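Production diffusion models accelerate this in several ways. DDIM gives deterministic sampling in fewer steps. Higher-order solvers such as DPM-Solver++ and EDM's Heun sampler do the same, and distillation trains a student that needs only a few steps. Classifier-free guidance is a separate addition. It improves adherence to the condition at some cost in diversity, but it does not reduce the step count; each guided step actually needs both a conditional and an unconditional forward pass. Our ancestral DDPM is the textbook version. The faster samplers reuse the same noise-prediction model, while distillation trains a new model from it.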

Total Code Inventory

Total: ~$250$ lines for the diffusion machinery (diffusion.py and dit.py), ~$200$ lines of glue. This implements the DiT architecture family that Stable Diffusion 3 (through its MMDiT variant) and, by OpenAI's own description, Sora build on. Those models use vastly larger networks and additional engineering: latent diffusion in a pre-trained autoencoder, classifier-free guidance, text conditioning, mixed precision and more.

What Part 3 Tests

Part 3 trains this model on $3{,}200$ synthetic 16$\times$16 colored-shape images. The result is not "perfect generation": at toy scale, with 323K parameters and no pretraining, the samples are blurry. What it is: a real demonstration of the inductive-bias tradeoff. U-Nets bake in image priors (locality, translation equivariance) that help when data is scarce; DiT doesn't. At LAION scale that becomes an asset; at toy scale it's a liability. The honest result is informative even when it isn't impressive.

Full code on GitHub: github.com/soveshmohapatra/DiT