The Transformer architecture has dominated deep learning for nearly a decade. Its core mechanism, self-attention, builds rich contextual representations by comparing every token with every other token. But this comes at a steep cost: attention's compute scales quadratically with sequence length $L$, and autoregressive generation must cache the keys and values of every past token, so inference memory grows as $O(L)$ with context.
Recurrent Neural Networks (RNNs) sit at the opposite extreme. They process sequences token by token, maintaining a fixed-size hidden state, which gives them $O(1)$ memory during inference. But RNNs suffer from vanishing gradients, cannot be parallelized across time during training (each step depends on the previous hidden state), and typically underperform Transformers on complex tasks.
What if an architecture combined the best of both worlds? Enter RWKV (Receptance Weighted Key Value), an architecture that trains in parallel like a Transformer but runs inference sequentially like an RNN. In this three-part series, we take RWKV apart from the theory, through a pure PyTorch implementation, to benchmarks.
The Key Insight: Attention as Linear Recurrence
The breakthrough behind RWKV comes from a deceptively simple observation: once the softmax is replaced with a suitable kernel, attention can be rewritten as a linear recurrence.
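Recall the standard self-attention formula:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right) V$$
For a single query position $t$ (with a causal mask, so only positions $i \le t$ are visible), the output is:
$$o_t = \frac{\sum_{i=1}^{t} \exp\left(q_t^\top k_i / \sqrt{d}\right) v_i}{\sum_{i=1}^{t} \exp\left(q_t^\top k_i / \sqrt{d}\right)}$$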
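The denominator normalizes the attention weights so they sum to one. Because each score $\exp(q_t^\top k_i / \sqrt{d})$ couples the query with a key, neither sum can be carried over from one position to the next: every new query must be compared against every earlier key. That is what makes attention quadratic in sequence length.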
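To make the quadratic cost concrete, here is a minimal NumPy sketch of naive causal softmax attention for a single head. NumPy, the function name `causal_softmax_attention`, and the toy sizes are our own choices for illustration, not code from this series. The `(T, T)` score matrix is the part that grows quadratically.

```python
import numpy as np

def causal_softmax_attention(q, k, v):
    """Naive causal attention for one head; q, k, v have shape (T, d)."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)                  # (T, T): every query vs. every key
    mask = np.tril(np.ones((T, T), dtype=bool))    # position t may only see i <= t
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # the per-query denominator
    return weights @ v, weights

rng = np.random.default_rng(0)
T, d = 6, 4
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
out, weights = causal_softmax_attention(q, k, v)
print(out.shape, weights.shape)  # (6, 4) (6, 6)
```

Doubling `T` quadruples the size of `weights`, and during generation every new query still has to touch all cached keys.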
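Now consider a linear attention variant that swaps this score for a factorized, element-wise exponential kernel, $\phi(q_t, k_i) = \exp(q_t) \odot \exp(k_i)$, applied independently to each channel. The numerator becomes:
$$\sum_{i=1}^{t} \exp(q_t) \odot \exp(k_i) \odot v_i = \exp(q_t) \odot \sum_{i=1}^{t} \exp(k_i) \odot v_i$$
The denominator factors the same way, into $\exp(q_t) \odot \sum_{i=1}^{t} \exp(k_i)$, so $\exp(q_t)$ cancels out of the ratio. What remains, $\sum_{i=1}^{t} \exp(k_i) \odot v_i$ and $\sum_{i=1}^{t} \exp(k_i)$, are cumulative sums that can be computed recursively, starting from $a_0 = b_0 = 0$ (division is element-wise):
$$a_t = a_{t-1} + \exp(k_t) \odot v_t, \qquad b_t = b_{t-1} + \exp(k_t), \qquad o_t = \frac{a_t}{b_t}$$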
This is the essence of RWKV: attention reformulated as a linear recurrence whose state ($a_t$ and $b_t$) has a fixed size, so inference needs only $O(1)$ memory per step, no matter how long the context is.
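Because $\exp(q_t)$ multiplies both the numerator and the denominator element-wise, it cancels, so the sketch below drops the query entirely. It computes the same per-channel ratio two ways: with cumulative sums over the whole sequence (training-style) and with the fixed-size $(a, b)$ state updated one token at a time (inference-style). This is a minimal NumPy sketch; the function names are hypothetical.

```python
import numpy as np

def linear_attention_parallel(k, v):
    """Element-wise exp-kernel attention for all positions at once; k, v have shape (T, d)."""
    a = np.cumsum(np.exp(k) * v, axis=0)  # running numerator: sum_{i<=t} exp(k_i) * v_i
    b = np.cumsum(np.exp(k), axis=0)      # running denominator: sum_{i<=t} exp(k_i)
    return a / b

def linear_attention_recurrent(k, v):
    """Same quantity, one token at a time with a constant-size state (a, b)."""
    T, d = k.shape
    a = np.zeros(d)
    b = np.zeros(d)
    out = np.empty((T, d))
    for t in range(T):
        a = a + np.exp(k[t]) * v[t]
        b = b + np.exp(k[t])
        out[t] = a / b
    return out

rng = np.random.default_rng(0)
k, v = rng.standard_normal((8, 4)), rng.standard_normal((8, 4))
print(np.allclose(linear_attention_parallel(k, v), linear_attention_recurrent(k, v)))  # True
```

The recurrent version never stores more than two `d`-sized vectors, whatever the sequence length.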
The WKV Operator
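RWKV introduces the WKV (Weighted Key-Value) operator, which builds on this recurrence with two refinements, a per-channel decay and a bonus for the current token:
$$\text{wkv}_t = \frac{\sum_{i=1}^{t-1} \exp\left(-(t-1-i)w + k_i\right) \odot v_i + \exp(u + k_t) \odot v_t}{\sum_{i=1}^{t-1} \exp\left(-(t-1-i)w + k_i\right) + \exp(u + k_t)}$$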
Here:
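- $w$ is a learned, non-negative, per-channel decay rate controlling how fast past information fades.
- $u$ is a learned per-channel "time-first" bonus applied only to the current token $t$, so the current token can be weighted differently from the decayed history.
- The factor $\exp(-(t-1-i)w)$ shrinks geometrically with the distance between positions $i$ and $t$, implementing a fading memory similar to an RNN's hidden state.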
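To see what $w$ and $u$ do, the sketch below evaluates the WKV formula directly and returns the normalized weight each position receives. It is a minimal NumPy sketch with hypothetical function names, using 0-indexed positions (the formula's form is unchanged under that shift). Setting all keys equal isolates the effect of the decay and the bonus.

```python
import numpy as np

def wkv_weights(k, w, u, t):
    """Normalized weight WKV assigns to each position 0..t (0-indexed), per channel.

    k: (T, d) keys; w: (d,) non-negative decay; u: (d,) current-token bonus.
    Returns an array of shape (t + 1, d) whose columns sum to 1.
    """
    i = np.arange(t)                               # past positions 0..t-1
    dist = (t - 1 - i)[:, None]                    # 0 for the previous token, larger further back
    past = np.exp(-dist * w + k[:t])               # exp(-(t-1-i) w + k_i)
    current = np.exp(u + k[t])[None, :]            # exp(u + k_t)
    raw = np.concatenate([past, current], axis=0)
    return raw / raw.sum(axis=0, keepdims=True)

def wkv_direct(k, v, w, u):
    """wkv_t for every t by summing over the full history: O(T^2) work."""
    T = k.shape[0]
    return np.stack([(wkv_weights(k, w, u, t) * v[: t + 1]).sum(axis=0) for t in range(T)])

k = np.zeros((5, 1))  # identical keys, so only w and u shape the weights
print(wkv_weights(k, w=np.array([0.5]), u=np.array([1.0]), t=4).ravel().round(3))
```

With equal keys, the past weights rise geometrically toward the most recent token, and the current token's weight is larger again by a factor of $\exp(u)$.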
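Both the numerator and the denominator can be computed recursively. Starting from $a_0 = b_0 = 0$:
$$a_t = \exp(-w) \odot a_{t-1} + \exp(k_t) \odot v_t, \qquad b_t = \exp(-w) \odot b_{t-1} + \exp(k_t)$$
$$\text{wkv}_t = \frac{a_{t-1} + \exp(u + k_t) \odot v_t}{b_{t-1} + \exp(u + k_t)}$$
Here $a_{t-1}$ and $b_{t-1}$ hold exactly the decayed sums over $i = 1, \dots, t-1$ from the WKV formula, so each step needs only the previous state and the current token.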
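A quick way to convince yourself the recurrence is exact is to compare it against a direct evaluation of the WKV formula. This is a minimal NumPy sketch with hypothetical function names and 0-indexed positions; it is not the series' implementation.

```python
import numpy as np

def wkv_recurrent(k, v, w, u):
    """wkv_t via the (a, b) recurrence: fixed-size state per channel, O(T) total work."""
    T, d = k.shape
    a = np.zeros(d)                                # decayed sum of exp(k_i) * v_i over i < t
    b = np.zeros(d)                                # decayed sum of exp(k_i) over i < t
    decay = np.exp(-w)
    out = np.empty((T, d))
    for t in range(T):
        bonus = np.exp(u + k[t])                   # extra weight on the current token
        out[t] = (a + bonus * v[t]) / (b + bonus)  # uses a_{t-1}, b_{t-1}
        a = decay * a + np.exp(k[t]) * v[t]        # a_t = e^{-w} a_{t-1} + e^{k_t} v_t
        b = decay * b + np.exp(k[t])               # b_t = e^{-w} b_{t-1} + e^{k_t}
    return out

def wkv_reference(k, v, w, u):
    """Direct O(T^2) evaluation of the WKV formula, used only as a check."""
    T, d = k.shape
    out = np.empty((T, d))
    for t in range(T):
        num = np.exp(u + k[t]) * v[t]
        den = np.exp(u + k[t])
        for i in range(t):
            weight = np.exp(-(t - 1 - i) * w + k[i])
            num = num + weight * v[i]
            den = den + weight
        out[t] = num / den
    return out

rng = np.random.default_rng(0)
T, d = 16, 8
k, v = rng.standard_normal((T, d)), rng.standard_normal((T, d))
w, u = np.abs(rng.standard_normal(d)), rng.standard_normal(d)
print(np.allclose(wkv_recurrent(k, v, w, u), wkv_reference(k, v, w, u)))  # True
```

With large keys, `np.exp(k)` can overflow; the reference RWKV implementations guard against this by also tracking a running maximum exponent in the state, which this sketch omits for clarity.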
Time Mixing vs. Channel Mixing
RWKV decomposes the Transformer block into two distinct modules:
Time Mixing (TM)
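The Time Mixing module is RWKV's equivalent of self-attention. It computes the WKV recurrence to mix information across time. The key components are:
- Key projection: $k_t = x_t W_k$
- Value projection: $v_t = x_t W_v$
- Receptance: $r_t = \sigma(x_t W_r)$ (a sigmoid gate that takes over the role of the query)
- WKV output: $\text{wkv}_t$ computed via the recurrence above
- Final output: $o_t = (r_t \odot \text{wkv}_t) W_o$ (element-wise gating, then an output projection $W_o$)
For brevity these formulas feed $x_t$ directly into each projection; the full RWKV design first applies token shift, a learned per-channel interpolation $\mu \odot x_t + (1 - \mu) \odot x_{t-1}$ between the current and previous inputs.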
The receptance gate is crucial. RWKV has no query in the usual sense, so the sigmoid receptance takes over its role, deciding channel by channel how much of the WKV output passes through.
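Putting the pieces together, one inference step of Time Mixing can be sketched as below. This is a minimal NumPy sketch under stated assumptions: token shift is omitted, the parameters live in a plain dict with hypothetical names, and an output projection `W_o` follows the gate as in the RWKV design. The PyTorch modules in Part 2 may organize this differently.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def time_mixing_step(x_t, state, params):
    """One inference step of a simplified Time Mixing module (no token shift).

    x_t: (d,) input for the current token.
    state: (a, b), each (d,), the WKV accumulators from previous tokens.
    params: dict with W_k, W_v, W_r, W_o of shape (d, d) and w, u of shape (d,).
    """
    a, b = state
    k = x_t @ params["W_k"]
    v = x_t @ params["W_v"]
    r = sigmoid(x_t @ params["W_r"])              # receptance gate in (0, 1)
    bonus = np.exp(params["u"] + k)               # current-token bonus
    wkv = (a + bonus * v) / (b + bonus)
    decay = np.exp(-params["w"])
    new_state = (decay * a + np.exp(k) * v, decay * b + np.exp(k))
    out = (r * wkv) @ params["W_o"]               # gate, then output projection
    return out, new_state

rng = np.random.default_rng(0)
d = 8
params = {name: rng.standard_normal((d, d)) / np.sqrt(d) for name in ("W_k", "W_v", "W_r", "W_o")}
params["w"], params["u"] = np.abs(rng.standard_normal(d)), rng.standard_normal(d)
state = (np.zeros(d), np.zeros(d))
for x_t in rng.standard_normal((5, d)):
    out, state = time_mixing_step(x_t, state, params)
print(out.shape, state[0].shape)  # (8,) (8,)
```

The state passed between calls stays two `d`-sized vectors no matter how many tokens have been processed.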
Channel Mixing (CM)
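The Channel Mixing module is RWKV's equivalent of the feed-forward network (FFN). It uses a gated linear unit (GLU) variant with a sigmoid receptance gate and a squared ReLU activation:
$$r'_t = \sigma(x_t W'_r), \qquad k'_t = x_t W'_k, \qquad o'_t = r'_t \odot \left(\text{ReLU}(k'_t)^2 \, W'_v\right)$$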
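The squared ReLU activation ($\text{ReLU}(x)^2$) keeps ReLU's sparsity, since every negative input still maps to exactly zero, but its derivative, $2\,\text{ReLU}(x)$, is continuous at zero instead of jumping from 0 to 1. It is a subtle design choice, motivated by the idea that smoother gradients make optimization better behaved.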
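The Channel Mixing formulas translate almost line for line into code. Below is a minimal NumPy sketch with hypothetical names; it omits token shift and assumes a hidden width `h` for the key projection (how RWKV sizes `h` is not covered here).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def squared_relu(x):
    return np.maximum(x, 0.0) ** 2

def channel_mixing(x, W_r, W_k, W_v):
    """Simplified Channel Mixing: a sigmoid-gated FFN with squared ReLU.

    x: (T, d); W_r: (d, d); W_k: (d, h); W_v: (h, d), with h a hidden width.
    Without token shift, each position is processed independently.
    """
    r = sigmoid(x @ W_r)        # receptance gate r'_t
    k = squared_relu(x @ W_k)   # ReLU(k'_t)^2
    return r * (k @ W_v)        # gate the projected hidden activations

# Negative inputs map to exactly 0; positive inputs are squared.
print(squared_relu(np.array([-2.0, -0.5, 0.0, 0.5, 2.0])))
```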
Why RWKV Matters
RWKV offers several compelling advantages over both Transformers and standard RNNs:
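- Parallel Training: Because the decay $w$ does not depend on the input, the WKV recurrence is linear and time-invariant. All projections run as ordinary batched matrix multiplications over the whole sequence, and the recurrence itself is cheap enough to evaluate in a single fused pass (or as a parallel scan), so training throughput can approach that of a Transformer.
- O(1) Inference: During autoregressive generation, RWKV keeps a constant-size state per layer (the accumulated numerator and denominator). There is no KV cache growing with every generated token.
- Linear Scaling: Training compute grows linearly with sequence length, and each inference step costs the same regardless of how many tokens came before, unlike the quadratic cost of Transformer attention.
- Transformer Performance: The RWKV authors report perplexity and accuracy comparable to similarly sized Transformers on language modeling tasks, something standard RNNs have generally not achieved.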
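The memory gap is easy to quantify with back-of-the-envelope arithmetic. The sketch below counts stored values, assuming standard multi-head attention (one key and one value vector of width `d_model` per token per layer, no key/value sharing) and counting only the $a$ and $b$ accumulators for RWKV. The model size is hypothetical, chosen purely for illustration.

```python
def kv_cache_floats(seq_len, n_layers, d_model):
    """Values a Transformer caches to decode: one key and one value per token per layer."""
    return 2 * n_layers * d_model * seq_len

def rwkv_state_floats(n_layers, d_model):
    """Values in the WKV state: the a and b accumulators (d_model each) per layer.

    Independent of sequence length. Real implementations keep a few more
    d_model-sized vectors per layer (e.g. the previous token's input), which
    changes only the constant.
    """
    return 2 * n_layers * d_model

n_layers, d_model = 24, 1024  # hypothetical model size, for illustration only
for seq_len in (1_000, 10_000, 100_000):
    print(seq_len, kv_cache_floats(seq_len, n_layers, d_model), rwkv_state_floats(n_layers, d_model))
```

At 100,000 tokens this toy configuration caches 4,915,200,000 values on the Transformer side, while the RWKV state stays at 49,152.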
Next Steps: From Math to Code
The WKV operator may seem abstract, but implementing it in PyTorch is surprisingly straightforward. In Part 2, we will build the Time Mixing and Channel Mixing modules from scratch, implement both parallel training and recurrent inference modes, and assemble a complete RWKV model capable of autoregressive text generation.