
Deconstructing DiT

Part 3: Training at Toy Scale, and What It Teaches

The Honest Result

Most blog posts about new architectures show impressive results. This one shows a mediocre one — and the mediocrity is the point. A $323{,}696$-parameter class-conditional DiT trained from scratch on $3{,}200$ synthetic $16 \times 16$ colored-shape images in $42$ seconds produces samples that have the right colors but blurry shapes. This is exactly what the inductive-bias tradeoff predicts: DiT's architecture lacks the locality bias that would let a U-Net of the same size produce sharper output on small data.

This is more informative than the textbook claim that "DiT is better than U-Net." The more accurate claim is "DiT is better than U-Net when the data is large enough", and $3{,}200$ examples is not large enough. This part shows what the small-data side of that crossover looks like.

Setup

Dataset. $16$ classes (4 colors $\times$ 4 shapes) at $16 \times 16$ RGB resolution. $200$ examples per class, $3{,}200$ total. With patch size $4$, each image becomes a $4 \times 4$ grid of $16$ patches, i.e. a sequence of $16$ tokens. That this matches the number of classes is a coincidence. What matters is that a $16$-token sequence is tiny, so attention over it is cheap.
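Here is a sketch of the patch step. It assumes channels-last $(H, W, C)$ images and uses a hypothetical helper, patchify; the post's own code may lay tensors out differently. A $16 \times 16 \times 3$ image with patch size $4$ becomes $16$ tokens of $4 \cdot 4 \cdot 3 = 48$ values each. A linear layer would then project each token to the hidden size of $64$.

```python
import numpy as np

def patchify(img: np.ndarray, p: int) -> np.ndarray:
    """Split an (H, W, C) image into (H/p * W/p) flattened p*p*C patches, row-major."""
    h, w, c = img.shape
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    # (H/p, p, W/p, p, C) -> (H/p, W/p, p, p, C) so each patch's pixels are contiguous
    x = img.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return x.reshape((h // p) * (w // p), p * p * c)

img = np.random.default_rng(0).random((16, 16, 3))
tokens = patchify(img, p=4)
print(tokens.shape)  # (16, 48): 16 tokens, each a flattened 4x4x3 patch
```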

Model. DiT with $\text{hidden} = 64$, $4$ layers, $4$ heads. Class-conditional via adaLN-Zero on the combined (timestep + class) vector. Total trainable parameters: $323{,}696$. This is two orders of magnitude smaller than the smallest published DiT (DiT-S at ~$33$M parameters trained on ImageNet at $256 \times 256$).
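The adaLN-Zero mechanism can be sketched in NumPy for one block at this size ($\text{hidden} = 64$, $4$ heads, $16$ tokens). This is an illustrative sketch, not the post's code:

- the names dit_block, modulate, self_attention and mlp are hypothetical;
- biases inside attention and the MLP are omitted for brevity;
- np.split stands in for PyTorch's chunk(6, dim=-1).

The sketch shows the property that gives adaLN-Zero its name. With the modulation projection initialised to zero, every gate is $0$ and the block starts as the identity.

```python
import numpy as np

rng = np.random.default_rng(0)
D, HEADS, N = 64, 4, 16  # hidden size, attention heads, patch tokens (from the setup)

def layer_norm(x, eps=1e-6):
    """LayerNorm without learnable affine parameters; adaLN supplies shift/scale."""
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

def self_attention(x, w_qkv, w_out):
    n, d = x.shape
    q, k, v = np.split(x @ w_qkv, 3, axis=-1)                         # each (N, D)
    heads = lambda t: t.reshape(n, HEADS, d // HEADS).transpose(1, 0, 2)
    q, k, v = heads(q), heads(k), heads(v)                             # (HEADS, N, D/HEADS)
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(d // HEADS)            # (HEADS, N, N)
    w = np.exp(logits - logits.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)                                      # softmax over keys
    return (w @ v).transpose(1, 0, 2).reshape(n, d) @ w_out

def mlp(x, w1, w2):
    h = x @ w1
    h = 0.5 * h * (1 + np.tanh(np.sqrt(2 / np.pi) * (h + 0.044715 * h**3)))  # tanh GELU
    return h @ w2

def dit_block(x, c, p):
    """One DiT block; c is the combined (timestep + class) conditioning vector."""
    silu_c = c / (1 + np.exp(-c))
    mods = silu_c @ p["w_ada"] + p["b_ada"]                            # (6 * D,)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = np.split(mods, 6)
    x = x + gate_msa * self_attention(modulate(layer_norm(x), shift_msa, scale_msa),
                                      p["w_qkv"], p["w_out"])
    x = x + gate_mlp * mlp(modulate(layer_norm(x), shift_mlp, scale_mlp), p["w1"], p["w2"])
    return x

params = {
    "w_qkv": rng.normal(0, 0.02, (D, 3 * D)),
    "w_out": rng.normal(0, 0.02, (D, D)),
    "w1": rng.normal(0, 0.02, (D, 4 * D)),
    "w2": rng.normal(0, 0.02, (4 * D, D)),
    "w_ada": np.zeros((D, 6 * D)),   # the "Zero" in adaLN-Zero
    "b_ada": np.zeros(6 * D),
}

x = rng.normal(size=(N, D))   # 16 patch tokens
c = rng.normal(size=D)        # stand-in for timestep embedding + class embedding
print(np.allclose(dit_block(x, c, params), x))  # True: all gates start at 0
```

Once training moves the modulation weights away from zero, the gates open and each block starts contributing.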

Diffusion. Linear $\beta$ schedule, $T = 200$ timesteps. Standard DDPM $\varepsilon$-prediction objective.
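Below is a minimal sketch of the schedule and objective. The post does not state its $\beta$ endpoints, so this assumes DDPM's linear endpoints ($10^{-4}$ to $0.02$, defined for $1000$ steps) rescaled by $1000/T$, as done in OpenAI's improved-diffusion code. Without that rescaling, $T = 200$ would leave $\bar\alpha_T \approx 0.13$, far from pure noise. The functions q_sample and eps_loss and the zero-noise stand-in model are hypothetical.

```python
import numpy as np

T = 200
# Assumed endpoints: DDPM's 1e-4 -> 0.02 (defined for 1000 steps), rescaled by 1000/T
scale = 1000 / T
betas = np.linspace(scale * 1e-4, scale * 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)  # abar_t = prod_{s <= t} alpha_s

def q_sample(x0, t, eps):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps (per-sample t)."""
    ab = alpha_bars[t].reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast over H, W, C
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps

def eps_loss(eps_model, x0, labels, rng):
    """DDPM epsilon-prediction objective: MSE between true and predicted noise."""
    t = rng.integers(0, T, size=x0.shape[0])  # one random timestep per example
    eps = rng.standard_normal(x0.shape)
    x_t = q_sample(x0, t, eps)
    return np.mean((eps_model(x_t, t, labels) - eps) ** 2)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(128, 16, 16, 3))  # stand-in batch scaled to [-1, 1]
labels = rng.integers(0, 16, size=128)
zero_model = lambda x, t, y: np.zeros_like(x)    # predicts "no noise" everywhere
print(alpha_bars[-1])                            # tiny: x_T is essentially pure noise
print(eps_loss(zero_model, x0, labels, rng))     # close to 1.0, since E[eps^2] = 1
```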

Training. AdamW with $\eta = 5 \times 10^{-4}$, weight decay $0.0$ (no explicit regularisation; the zero-initialised adaLN-Zero gates are relied on to keep early training stable), gradient clipping at $1.0$, cosine schedule over $30$ epochs, batch size $128$.
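The step arithmetic behind this recipe is $3{,}200 / 128 = 25$ steps per epoch, so $750$ optimizer steps in total. The sketch below assumes the cosine schedule is stepped per batch and decays to zero with no warmup; the post does not say which. The function cosine_lr is hypothetical. The comments note how the same recipe would map onto PyTorch's AdamW, CosineAnnealingLR and clip_grad_norm_.

```python
import math

N_TRAIN, BATCH, EPOCHS = 3200, 128, 30
PEAK_LR = 5e-4

steps_per_epoch = math.ceil(N_TRAIN / BATCH)  # 25
total_steps = EPOCHS * steps_per_epoch         # 750

def cosine_lr(step: int, total: int = total_steps, peak: float = PEAK_LR) -> float:
    """Cosine decay from `peak` at step 0 to 0 at `total` (no warmup assumed)."""
    progress = min(step / total, 1.0)
    return 0.5 * peak * (1.0 + math.cos(math.pi * progress))

# Rough PyTorch equivalent (not run here):
#   opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)
#   sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=750)
#   per step: loss.backward()
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#             opt.step(); sched.step(); opt.zero_grad()

for step in (0, 375, 750):
    print(step, cosine_lr(step))
```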

Hardware. CPU. On Apple Silicon, the MPS backend has a stride-handling bug in the backward pass of the adaLN-Zero chunk(6, dim=-1): gradients are silently miscomputed. CUDA reproduces the CPU numbers, so the issue is Apple-specific, and the toy experiment falls back to CPU to avoid it. Total wall-clock: $42.7$ seconds.

Loss Trajectory

| Epoch | MSE (noise prediction) |
|---|---|
| $1$ | $0.9147$ |
| $4$ | $0.2467$ |
| $8$ | $0.1685$ |
| $16$ | $0.1409$ |
| $24$ | $0.1262$ |
| $30$ | $0.1228$ |

Loss falls by almost a factor of four in the first four epochs. In this regime the model is learning the gross statistics of the dataset (average color distribution, average shape size, etc.). After that, the loss keeps decreasing slowly. For scale, a model that always predicts zero noise scores an MSE of $1$, because $\varepsilon \sim \mathcal{N}(0, I)$.

The final value of $0.1228$ is reasonable for DDPM noise prediction on a low-resolution dataset. Even a perfectly trained model cannot reach zero, because the clean image $x_0$ is not fully determined by the noisy input $x_t$, and so neither is $\varepsilon$. With the small added jitter in our $16 \times 16$ RGB images, we estimate this floor at around $0.05$ to $0.10$.

The loss is still slowly decreasing at epoch $30$. Training longer would likely improve the samples further. We stop at $30$ epochs to keep the toy experiment under a minute; production DiT training runs for hundreds of thousands of iterations.

Generated Samples

After training, we generate one image per class by sampling pure Gaussian noise and running $200$ ancestral denoising steps conditioned on the target class label. The results are mixed in a way that is genuinely informative.
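The ancestral loop can be sketched as follows. This is a hedged illustration rather than the post's sampler:

- eps_model is a hypothetical stand-in for the trained DiT, called as eps_model(x, t, class_label);
- the schedule is the assumed rescaled linear one;
- the variance is fixed at $\sigma_t^2 = \beta_t$, one of the two choices discussed in the DDPM paper.

```python
import numpy as np

T = 200
scale = 1000 / T  # assumed rescaled linear schedule
betas = np.linspace(scale * 1e-4, scale * 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

def ddpm_sample(eps_model, class_label, shape, rng):
    """Ancestral DDPM sampling: start from N(0, I), denoise t = T-1, ..., 0."""
    x = rng.standard_normal(shape)
    for t in reversed(range(T)):
        eps_hat = eps_model(x, t, class_label)
        # mu_theta = (x_t - beta_t / sqrt(1 - abar_t) * eps_hat) / sqrt(alpha_t)
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alphas[t])
        if t > 0:
            x = mean + np.sqrt(betas[t]) * rng.standard_normal(shape)  # sigma_t^2 = beta_t
        else:
            x = mean  # the last step adds no noise
    return x

# One sample per class with a dummy model (a trained DiT would go here)
rng = np.random.default_rng(0)
dummy = lambda x, t, y: np.zeros_like(x)
samples = [ddpm_sample(dummy, c, (16, 16, 3), rng) for c in range(16)]
print(len(samples), samples[0].shape)  # 16 (16, 16, 3)
```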

What is correct. Colors are recovered consistently. Every "red" class produces red samples, every "blue" class produces blue, etc. The class conditioning works — the model has learned to associate each class index with a particular color distribution. This is what success looks like for the conditioning mechanism in isolation.

What is approximate. Shape geometry is recognisable but blurry. Circles are roughly round but with soft, fuzzy edges. Squares are roughly square but the corners are rounded and the sides aren't quite straight. Triangles are roughly triangular but the points are blunted. Crosses — the highest-frequency shape in the dataset, requiring two thin orthogonal strokes — are the hardest, often collapsing toward a smudge with the right color but no clear cross structure.

What is missing. Sharp edges. Clean geometric primitives. The fine-grained spatial structure that distinguishes shapes from coloured blobs. These are the things our model cannot produce at this scale.

This is the honest result. At $3{,}200$ examples, $323K$ parameters, and $16 \times 16$ resolution, DiT is undertrained for sharp geometric primitives. The samples have the right marginal distributions (color) but lack the high-frequency spatial detail that would make them recognisably squares versus circles versus triangles.

Why the Samples Are Blurry

Three confounded reasons, all in the same direction.

(1) Resolution is $16 \times 16$. A sharp triangle has only a few pixels per edge at this resolution, so "triangle" and "circle" differ by only a few pixels. The model has very little signal from which to learn shape geometry. Even a perfectly trained model would struggle to make $16 \times 16$ triangles look much sharper than $16 \times 16$ blobs.
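A quick count makes the "few pixels" point concrete. On a $16 \times 16$ grid, a radius-$5$ disc and a $10 \times 10$ square drawn around the same centre differ only in their $20$ corner pixels, under $8\%$ of the image. These shape sizes are illustrative assumptions, not the dataset's actual generator.

```python
import numpy as np

yy, xx = np.mgrid[0:16, 0:16]  # pixel-centre coordinates on a 16x16 grid
circle = (xx - 7.5) ** 2 + (yy - 7.5) ** 2 <= 5.0 ** 2          # radius-5 disc
square = (np.abs(xx - 7.5) <= 5.0) & (np.abs(yy - 7.5) <= 5.0)  # 10x10 square
differ = np.logical_xor(circle, square)
print(circle.sum(), square.sum(), differ.sum())  # 80 100 20
```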

(2) Dataset is tiny. $3{,}200$ examples is enough to learn approximate marginals (the colors), but not enough to learn sharp geometric primitives from scratch — especially with a low-inductive-bias architecture like DiT that has no prior on locality. A model with a U-Net's convolutional prior could plausibly do better with this little data because it would not have to learn that nearby pixels matter from scratch.

(3) Model is small. $323{,}696$ parameters is two orders of magnitude smaller than the smallest published DiT (DiT-S at $\sim 33$M parameters on ImageNet). The smallest model in the original DiT paper trains for orders of magnitude longer than we do, on orders of magnitude more data, at much higher resolution. We are at the "this barely works" end of the regime.

All three of these point in the same direction: at this scale, the architecture's lack of inductive bias is a liability.

Why DiT Is Undertrained Here But Dominant at Scale

A U-Net of the same parameter count should produce sharper samples on the same dataset. U-Nets bake in three image priors that DiT has to learn from data:

Locality. A convolution connects each pixel only to its immediate neighbours; information from distant pixels arrives only through depth. This matches natural images and immediately tells the model that shapes are local structures.

Translation equivariance. The same convolution kernel applies everywhere in the image. A circle in the top-left and a circle in the bottom-right are processed by the same filters. This means the model doesn't have to learn separately what a circle looks like at each position.

Multi-scale processing. U-Nets downsample to compute coarse features, then upsample with skip connections. This lets them naturally represent both global structure (where is the shape?) and local detail (what does its edge look like?) without having to discover this hierarchy from scratch.
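The first two priors can be checked directly on a toy $16 \times 16$ image. The sketch assumes wrap-around (circular) padding so that translation equivariance holds exactly; with zero padding it holds everywhere except near the borders. conv3x3_circular is a hypothetical helper built on np.roll. Like deep-learning "convolution" layers, it computes a weighted sum of shifted neighbours (cross-correlation).

```python
import numpy as np

def conv3x3_circular(img, kernel):
    """3x3 convolution with wrap-around padding: a weighted sum of shifted copies."""
    out = np.zeros_like(img)
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            # np.roll puts img[y - dy, x - dx] (indices wrap around) at position (y, x)
            out += kernel[dy + 1, dx + 1] * np.roll(img, shift=(dy, dx), axis=(0, 1))
    return out

rng = np.random.default_rng(0)
img, kernel = rng.random((16, 16)), rng.random((3, 3))
shift = (5, -3)

# Translation equivariance: conv(shift(img)) == shift(conv(img))
lhs = conv3x3_circular(np.roll(img, shift, axis=(0, 1)), kernel)
rhs = np.roll(conv3x3_circular(img, kernel), shift, axis=(0, 1))
print(np.allclose(lhs, rhs))  # True

# Locality: perturbing a far-away pixel leaves the output at (8, 8) unchanged
bumped = img.copy()
bumped[0, 0] += 10.0
print(conv3x3_circular(bumped, kernel)[8, 8] == conv3x3_circular(img, kernel)[8, 8])  # True
```

Self-attention has neither property by construction: every token can attend to every other token, and any preference for neighbours has to be learned.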

These three priors are correct for natural images. They let a U-Net learn good image structure from very little data, which is exactly the regime we are in. DiT has none of these priors:

- After patch embedding, self-attention connects every patch to every other patch, so nothing in the architecture favours nearby patches.
- The position embedding (a fixed 2D sine-cosine embedding in the original DiT) says where each patch is. The model still has to learn from data that nearby positions matter more.
- Multi-scale processing happens implicitly through depth, if at all.
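For reference, a fixed 2D sine-cosine position embedding in the style of the original DiT can be sketched as below, for our $4 \times 4$ patch grid and $\text{hidden} = 64$. Whether the toy model uses this or a learned table is an assumption, and the helpers sincos_1d and sincos_2d are hypothetical. Half the channels encode the row and half the column. The embedding is simply added to each token, so it says where a patch is but does not force attention to prefer neighbouring patches.

```python
import numpy as np

def sincos_1d(dim, positions):
    """Standard 1D sine-cosine embedding: dim/2 frequencies, sin then cos."""
    omega = 1.0 / 10000 ** (np.arange(dim // 2) / (dim / 2.0))
    angles = np.outer(positions, omega)  # (num_positions, dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def sincos_2d(dim, grid_size):
    """2D version: first half of the channels encodes the row, second half the column."""
    rows, cols = np.meshgrid(np.arange(grid_size), np.arange(grid_size), indexing="ij")
    emb_rows = sincos_1d(dim // 2, rows.reshape(-1))
    emb_cols = sincos_1d(dim // 2, cols.reshape(-1))
    return np.concatenate([emb_rows, emb_cols], axis=1)  # (grid_size**2, dim)

pos = sincos_2d(dim=64, grid_size=4)
print(pos.shape)  # (16, 64): one fixed vector per patch, added to the patch tokens
```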

With $3{,}200$ examples, DiT cannot fully learn locality or translation equivariance: it does not have enough data to discover from scratch that nearby pixels matter more than distant ones. Such a small corpus does not reveal these basic image priors, so the model produces reasonable colors but poor geometry.

At LAION/JFT Scale, the Picture Inverts

At $400$ million images (LAION scale, the regime production diffusion models train in), the tradeoff flips. The U-Net's hardcoded biases now act as constraints rather than assists, because the data rewards structure that goes beyond "locality + translation equivariance".

A convolutional prior makes that structure harder to learn. DiT has no prior either way, so it learns whatever the data shows it.

This is why many of the leading text-to-image and text-to-video models moved to transformer backbones, among them Stable Diffusion 3, Sora, PixArt-$\alpha$ and Lumina-T2I. The inductive-bias tradeoff that hurts DiT at $3{,}200$ examples is the same tradeoff that wins at $400$ million.

The Inductive-Bias Tradeoff, Quantified

The general pattern: inductive biases buy sample efficiency at small scale at the cost of a lower ceiling at large scale. A U-Net with strong locality bias is more sample-efficient on small data; it has a higher floor. DiT, with weak inductive bias, is less sample-efficient on small data but has a higher ceiling on large data.

The crossover point is empirical and depends on resolution, model size, and task difficulty. For class-conditional $256 \times 256$ ImageNet generation, the crossover is somewhere in the millions of images. For text-conditional $1024 \times 1024$ image generation (the SD3/Sora regime), the crossover is in the hundreds of millions.

Our $3{,}200$-image experiment is firmly on the U-Net side of the crossover. That the DiT architecture learned recognisable colors and rough shapes at all is itself remarkable — evidence that the architecture is functioning correctly. What it cannot do at this scale is precisely the thing U-Nets do as a free gift from their inductive bias: produce sharp output from a small corpus.

The Denoising Trajectory

Watching a single noise vector get denoised step by step (the "denoising trajectory" plot) shows the canonical DDPM behaviour: at $t = T - 1$ the image is pure noise; over the first $50\%$ of the timesteps the noise gradually coalesces into a colored blob; over the next $30\%$ the blob takes shape; over the final $20\%$ the shape sharpens (modestly, in our case) into recognisable structure.

This is the standard DDPM picture, and it shows up the same way across diffusion papers. Our trajectory looks like the textbook one, just at toy scale: the model is functioning correctly, and the limitation is data and resolution.

What This Series Demonstrates

The architecture is sound:

- Class conditioning via adaLN-Zero works.
- The denoising trajectory has the canonical DDPM appearance.
- Training is stable (no NaN losses, no diverging gradients).
- The samples have the right colors and recognisable rough shapes.

What the architecture cannot do at this scale is produce sharp geometry from a small corpus.

The lesson is not "DiT doesn't work": DiT works fine, and at scale it dominates U-Nets. The lesson is that the architectural choice depends on which side of the inductive-bias crossover you are training on. Small data, small model, small resolution: pick the U-Net. Large data, large model, large resolution: pick the DiT.

This Was the Experiment I Built Before Writing the Post

Writing this post in advance of running the experiment would have produced the textbook claim "DiT > U-Net" with no nuance. Running the experiment first and writing afterward produced the actual finding: "DiT > U-Net when data is large enough; U-Net > DiT when data is small."

This is the right way to run a build-in-public series. Run the experiment, see what happens, and write up what you saw. If the result is unexpected, that is the most interesting outcome. The textbook story is rarely as clean as it is presented, and the discrepancies are where the real learning happens.

Scaling This Up

To get from "toy DiT that almost works" to "production DiT", you would scale several knobs:

Resolution. $16 \times 16$ to $256 \times 256$ or $1024 \times 1024$. Latent diffusion in a pretrained VAE is the standard approach — the DiT operates on the VAE's latent grid, not directly on pixels.
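To see why latent diffusion matters, count tokens. Assume the Stable Diffusion VAE's $8\times$ spatial downsampling and DiT's patch size $2$ (the DiT-XL/2 setting):

- $256 \times 256$ becomes a $32 \times 32$ latent and $256$ tokens.
- $1024 \times 1024$ gives $4{,}096$ tokens.

Attention cost grows with the square of the token count, so $4{,}096$ tokens means $65{,}536$ times as many token pairs as our $16$-token toy. The helper num_tokens is hypothetical.

```python
def num_tokens(image_size: int, downsample: int, patch: int) -> int:
    """Tokens seen by the transformer: (image_size / downsample / patch) squared."""
    grid = image_size // downsample // patch
    return grid * grid

toy = num_tokens(16, downsample=1, patch=4)        # pixel space, as in this post
dit_256 = num_tokens(256, downsample=8, patch=2)   # assumed 8x VAE latent, patch 2
dit_1024 = num_tokens(1024, downsample=8, patch=2)
print(toy, dit_256, dit_1024)                      # 16 256 4096
print((dit_1024 / toy) ** 2)                       # 65536.0: growth in attention pairs
```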

Model size. $323K$ to $1B$+ parameters. DiT-XL/2 is roughly $675M$ parameters. The largest Stable Diffusion 3 variant reported by Stability AI is about $8B$, and Sora's size has not been disclosed.

Dataset. $3{,}200$ to $400M$+ examples. Most production diffusion models train on web-scraped image-caption pairs (LAION, COYO, internal datasets).

Conditioning. Class indices to natural-language text. Add a small text cross-attention layer to the DiT block (the original DiT paper evaluated a cross-attention variant alongside adaLN-Zero), and keep adaLN-Zero for the timestep conditioning.
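Here is a single-head sketch of that text cross-attention. Image tokens provide the queries, text-encoder tokens provide the keys and values, and the result is added back residually. The shapes, the text width of $32$ and the function names are hypothetical. A real block would use multiple heads, a norm, and projections trained with the rest of the model.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(x, text, w_q, w_k, w_v, w_o):
    """Single-head cross-attention: image tokens (N, D) attend over text tokens (M, D_text)."""
    q = x @ w_q                                       # (N, D)
    k, v = text @ w_k, text @ w_v                     # (M, D)
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))    # (N, M): each patch weighs each text token
    return x + (attn @ v) @ w_o                       # residual update, same shape as x

rng = np.random.default_rng(0)
D, D_TEXT, N, M = 64, 32, 16, 8   # hidden, text width, patches, text tokens (all assumed)
x = rng.normal(size=(N, D))
text = rng.normal(size=(M, D_TEXT))
w_q = rng.normal(0, 0.02, (D, D))
w_k, w_v = rng.normal(0, 0.02, (D_TEXT, D)), rng.normal(0, 0.02, (D_TEXT, D))
w_o = rng.normal(0, 0.02, (D, D))
out = cross_attention(x, text, w_q, w_k, w_v, w_o)
print(out.shape)  # (16, 64)
```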

Sampling improvements. Replace ancestral DDPM with DDIM, DPM-Solver++, or EDM-style samplers. These cut the $200$ sampling steps to $20$–$50$ with little loss in quality.
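As one example of a faster sampler, here is a sketch of a deterministic DDIM ($\eta = 0$) loop over $25$ of the $200$ timesteps. It reuses the assumed rescaled linear schedule, and eps_model is again a hypothetical stand-in for the trained network. Each step predicts $\hat{x}_0$ from the current noise estimate and re-noises it to the next, earlier timestep. The check at the end uses an oracle noise predictor built from a known target; for that predictor, DDIM returns the target up to rounding.

```python
import numpy as np

T = 200
scale = 1000 / T  # assumed rescaled linear schedule
betas = np.linspace(scale * 1e-4, scale * 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)

def ddim_sample(eps_model, class_label, shape, rng, num_steps=25):
    """Deterministic DDIM (eta = 0) over an evenly spaced subset of timesteps."""
    ts = np.linspace(T - 1, 0, num_steps).round().astype(int)  # 199, ..., 0
    x = rng.standard_normal(shape)
    for i, t in enumerate(ts):
        eps_hat = eps_model(x, t, class_label)
        x0_hat = (x - np.sqrt(1.0 - alpha_bars[t]) * eps_hat) / np.sqrt(alpha_bars[t])
        if i + 1 < len(ts):
            ab_next = alpha_bars[ts[i + 1]]
            x = np.sqrt(ab_next) * x0_hat + np.sqrt(1.0 - ab_next) * eps_hat
        else:
            x = x0_hat
    return x

# Oracle check: if the noise prediction is exact for a known image, DDIM recovers it
rng = np.random.default_rng(0)
target = rng.uniform(-1, 1, size=(16, 16, 3))
oracle = lambda x, t, y: (x - np.sqrt(alpha_bars[t]) * target) / np.sqrt(1.0 - alpha_bars[t])
result = ddim_sample(oracle, class_label=0, shape=(16, 16, 3), rng=rng)
print(np.allclose(result, target))  # True
```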

The core architecture and the diffusion machinery stay the same. Most of the differences are in the magnitudes, which is exactly the scaling story DiT was designed for.

Summary

Full code on GitHub: github.com/soveshmohapatra/DiT