Self-attention compares every token to every other token. This gives Transformers rich contextual representations, but at a cost: attention is $O(L^2)$ in sequence length, and autoregressive generation requires a KV cache that grows as $O(L)$.
RNNs have the opposite profile. They maintain a fixed-size hidden state, giving $O(1)$ inference memory, but their training can't be parallelized across time steps, and they have historically underperformed Transformers.
RWKV (Receptance Weighted Key Value) bridges the gap: it trains in parallel like a Transformer and infers like an RNN with constant memory. This series covers the theory (Part 1), a pure PyTorch implementation (Part 2), and a head-to-head benchmark (Part 3).
The Key Insight: Attention as Linear Recurrence
The core observation behind RWKV: attention can be rewritten as a linear recurrence.
Standard self-attention:

$$\text{Attn}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

For a single query position $t$ (with causal masking, dropping the $\sqrt{d}$ scaling for clarity):

$$\text{Attn}_t = \frac{\sum_{i=1}^{t} \exp(q_t^\top k_i) \cdot v_i}{\sum_{i=1}^{t} \exp(q_t^\top k_i)}$$

The denominator depends on all previous keys, and the weight $\exp(q_t^\top k_i)$ couples query and key inside the exponential, so nothing can be precomputed independently of $t$--this is what forces the $O(L^2)$ computation.
Now replace softmax with an exponential kernel $\phi(q, k) = \exp(q) \cdot \exp(k)$, applied channel-wise. The numerator factors:

$$\sum_{i=1}^{t} \phi(q_t, k_i) \cdot v_i = \exp(q_t) \cdot \sum_{i=1}^{t} \exp(k_i) \cdot v_i$$

That inner sum $\sum_{i=1}^{t} \exp(k_i) \cdot v_i$ no longer depends on $q_t$: it is a cumulative sum, computable as a recurrence:

$$s_t = s_{t-1} + \exp(k_t) \cdot v_t, \qquad s_0 = 0$$

The denominator factors the same way, with $v_i$ replaced by $1$.
This is the essence of RWKV: attention reformulated as a linear recurrence with $O(1)$ state.
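To make the factorization concrete, here is a minimal PyTorch sketch (our own illustration, not from the paper; all tensor names are ours) that computes exponential-kernel attention both directly and via the running sums, then checks that the two agree:

```python
import torch

torch.manual_seed(0)
L, D = 8, 4                                  # sequence length, channels
q, k, v = (torch.randn(L, D) for _ in range(3))

# Direct form: for each t, sum phi(q_t, k_i) * v_i over i <= t,
# with the channel-wise kernel phi(q, k) = exp(q) * exp(k).
direct = torch.stack([
    (torch.exp(q[t]) * (torch.exp(k[:t + 1]) * v[:t + 1]).sum(0))
    / (torch.exp(q[t]) * torch.exp(k[:t + 1]).sum(0))
    for t in range(L)
])

# Recurrent form: carry a running numerator s and denominator z.
s, z = torch.zeros(D), torch.zeros(D)
out = []
for t in range(L):
    s = s + torch.exp(k[t]) * v[t]           # s_t = s_{t-1} + exp(k_t) * v_t
    z = z + torch.exp(k[t])                  # z_t = z_{t-1} + exp(k_t)
    out.append(torch.exp(q[t]) * s / (torch.exp(q[t]) * z))

print(torch.allclose(direct, torch.stack(out)))  # True
```

Note that with a channel-wise kernel, $\exp(q_t)$ cancels between numerator and denominator, so the query contributes nothing; this is why RWKV replaces the query with a separate learned gate, the receptance, introduced below.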
The WKV Operator
RWKV formalizes this recurrence with the WKV (Weighted Key-Value) operator:

$$\text{wkv}_t = \frac{\sum_{i=1}^{t-1} \exp\big(-(t-1-i)w + k_i\big) \cdot v_i + \exp(u + k_t) \cdot v_t}{\sum_{i=1}^{t-1} \exp\big(-(t-1-i)w + k_i\big) + \exp(u + k_t)}$$
Three components:
- $w$ -- a learned decay controlling how fast past information fades.
- $u$ -- a learned time-first parameter that upweights the current token.
- The exponential decay $\exp(-(t-1-i)w)$ implements fading memory, analogous to an RNN hidden state.
Both numerator and denominator admit a simple recurrence. Writing $a_t$ for the decayed numerator and $b_t$ for the decayed denominator:

$$a_t = e^{-w} a_{t-1} + e^{k_t} \cdot v_t, \qquad b_t = e^{-w} b_{t-1} + e^{k_t}$$

$$\text{wkv}_t = \frac{a_{t-1} + e^{u + k_t} \cdot v_t}{b_{t-1} + e^{u + k_t}}$$
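The recurrence is easy to verify numerically. The sketch below (ours; note the naive exponentials would overflow on long sequences, which real implementations avoid by tracking a running maximum exponent) computes WKV both from the definition and from the recurrence:

```python
import torch

torch.manual_seed(0)
L, D = 8, 4
k, v = torch.randn(L, D), torch.randn(L, D)
w = torch.rand(D)            # per-channel decay, >= 0 so the past fades
u = torch.randn(D)           # per-channel time-first bonus

# Direct form, straight from the WKV definition.
direct = []
for t in range(L):
    num, den = torch.exp(u + k[t]) * v[t], torch.exp(u + k[t])
    for i in range(t):
        decay = torch.exp(-(t - 1 - i) * w + k[i])
        num, den = num + decay * v[i], den + decay
    direct.append(num / den)

# Recurrent form: a, b carry the decayed numerator and denominator.
a, b = torch.zeros(D), torch.zeros(D)
rec = []
for t in range(L):
    e = torch.exp(u + k[t])
    rec.append((a + e * v[t]) / (b + e))
    a = torch.exp(-w) * a + torch.exp(k[t]) * v[t]
    b = torch.exp(-w) * b + torch.exp(k[t])

print(torch.allclose(torch.stack(direct), torch.stack(rec)))  # True
```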
Time Mixing vs. Channel Mixing
Each RWKV block has two sub-modules:
Time Mixing (TM)
RWKV's replacement for self-attention. It runs the WKV recurrence to mix information across time steps:
- Key: $k_t = x_t W_k$
- Value: $v_t = x_t W_v$
- Receptance: $r_t = \sigma(x_t W_r)$ -- a sigmoid gate playing the role of the query
- Output: $o_t = r_t \odot \text{wkv}_t$
The receptance gate controls how much of the WKV output passes through, functioning like a learned information filter.
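Putting the pieces together, a time-mixing block might look like the following sketch (ours, not the reference code; it runs the recurrence step by step and omits refinements such as token shift and the output projection):

```python
import torch
import torch.nn as nn

class TimeMixing(nn.Module):
    """Sketch of RWKV time mixing: the WKV recurrence gated by receptance."""

    def __init__(self, dim: int):
        super().__init__()
        self.W_k = nn.Linear(dim, dim, bias=False)
        self.W_v = nn.Linear(dim, dim, bias=False)
        self.W_r = nn.Linear(dim, dim, bias=False)
        self.w = nn.Parameter(torch.zeros(dim))   # raw decay parameter
        self.u = nn.Parameter(torch.zeros(dim))   # time-first bonus

    def forward(self, x):                          # x: (L, dim)
        L, dim = x.shape
        k, v = self.W_k(x), self.W_v(x)
        r = torch.sigmoid(self.W_r(x))             # receptance gate in (0, 1)
        w = nn.functional.softplus(self.w)         # keep the decay >= 0
        a, b = torch.zeros(dim), torch.zeros(dim)
        out = []
        for t in range(L):
            e = torch.exp(self.u + k[t])
            out.append(r[t] * (a + e * v[t]) / (b + e))  # o_t = r_t ⊙ wkv_t
            a = torch.exp(-w) * a + torch.exp(k[t]) * v[t]
            b = torch.exp(-w) * b + torch.exp(k[t])
        return torch.stack(out)

x = torch.randn(16, 64)
print(TimeMixing(64)(x).shape)  # torch.Size([16, 64])
```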
Channel Mixing (CM)
RWKV's replacement for the FFN, using a GLU variant with squared ReLU:

$$o_t = \sigma(x_t W_r) \odot \big(\text{ReLU}(x_t W_k)^2 \, W_v\big)$$
Squared ReLU ($\text{ReLU}(x)^2$) has a derivative, $2\,\text{ReLU}(x)$, that is continuous at zero, giving smoother gradients than standard ReLU while still zeroing out negative inputs and preserving sparsity.
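As a sketch (again ours, with token shift omitted), channel mixing is only a few lines:

```python
import torch
import torch.nn as nn

class ChannelMixing(nn.Module):
    """Sketch of RWKV channel mixing: a receptance-gated FFN, squared ReLU."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.W_k = nn.Linear(dim, hidden, bias=False)
        self.W_v = nn.Linear(hidden, dim, bias=False)
        self.W_r = nn.Linear(dim, dim, bias=False)

    def forward(self, x):                          # x: (L, dim)
        k = torch.relu(self.W_k(x)) ** 2           # squared ReLU, stays sparse
        return torch.sigmoid(self.W_r(x)) * self.W_v(k)  # gate * up-down proj

x = torch.randn(16, 64)
print(ChannelMixing(64, 256)(x).shape)  # torch.Size([16, 64])
```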
Why RWKV Matters
- Parallel Training: The WKV recurrence unrolls into a cumulative sum, so full sequences can be processed in parallel during training.
- O(1) Inference: At generation time, the model carries only the accumulated numerator and denominator--no KV cache (see the sketch after this list).
- Linear Scaling: Training and inference both scale linearly with sequence length.
- Competitive Accuracy: Unlike vanilla RNNs, RWKV reports language-modeling perplexity on par with similarly sized Transformers.
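To make the $O(1)$ claim concrete, here is a sketch (ours) of a streaming WKV step: the entire recurrent state is two $D$-dimensional vectors, no matter how many tokens have been consumed:

```python
import torch

def wkv_step(state, k_t, v_t, w, u):
    """One WKV step; state is just (a, b), whose size never grows with t."""
    a, b = state
    e = torch.exp(u + k_t)
    wkv = (a + e * v_t) / (b + e)
    a = torch.exp(-w) * a + torch.exp(k_t) * v_t
    b = torch.exp(-w) * b + torch.exp(k_t)
    return wkv, (a, b)

# Stream an arbitrarily long sequence in constant memory--no KV cache.
D = 4
w, u = torch.rand(D), torch.randn(D)
state = (torch.zeros(D), torch.zeros(D))
for _ in range(100_000):
    k_t, v_t = torch.randn(D), torch.randn(D)
    wkv, state = wkv_step(state, k_t, v_t, w, u)
```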