
Deconstructing RWKV

Part 1: Linear Attention and the WKV Operator

Introduction

Self-attention compares every token to every other token. This gives Transformers rich contextual representations, but at a cost: attention is $O(L^2)$ in sequence length, and autoregressive generation requires a KV cache that grows as $O(L)$.

RNNs have the opposite profile. They maintain a fixed-size hidden state, giving $O(1)$ inference memory, but they can't parallelize during training and tend to underperform Transformers.

RWKV (Receptance Weighted Key Value) bridges the gap: it trains in parallel like a Transformer and infers like an RNN with constant memory. This series covers the theory (Part 1), a pure PyTorch implementation (Part 2), and a head-to-head benchmark (Part 3).

The Key Insight: Attention as Linear Recurrence

The core observation behind RWKV: attention can be rewritten as a linear recurrence.

Standard self-attention:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$

For a single query position $t$:

$$ y_t = \sum_{i=1}^{t} \frac{\exp(q_t \cdot k_i / \sqrt{d})}{\sum_{j=1}^{t} \exp(q_t \cdot k_j / \sqrt{d})} v_i $$

The denominator depends on all previous keys; this normalization is what forces the $O(L^2)$ computation.
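To make the cost concrete, here is a minimal sketch (function and variable names are illustrative) of causal attention evaluated one query position at a time, which shows the growing normalization explicitly:

```python
import torch

def causal_attention(q, k, v):
    """Causal softmax attention, one query position at a time.

    q, k, v: (L, d). The softmax at position t is normalized over keys
    1..t, so per-token work grows with t and the total cost is O(L^2).
    """
    L, d = q.shape
    out = torch.empty_like(v)
    for t in range(L):
        scores = (k[: t + 1] @ q[t]) / d**0.5   # dot products q_t . k_i, shape (t+1,)
        weights = torch.softmax(scores, dim=0)  # denominator spans every previous key
        out[t] = weights @ v[: t + 1]           # weighted sum of values seen so far
    return out
```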

Now replace softmax with an exponential kernel $\phi(q, k) = \exp(q) \cdot \exp(k)$, applied element-wise per channel. The numerator then factors:

$$ \sum_{i=1}^{t} \exp(q_t) \cdot \exp(k_i) \cdot v_i = \exp(q_t) \cdot \sum_{i=1}^{t} \exp(k_i) \cdot v_i $$

That inner sum $\sum_{i=1}^{t} \exp(k_i) \cdot v_i$ is a cumulative sum, computable as a recurrence (the denominator factors the same way, into the cumulative sum $\sum_{i=1}^{t} \exp(k_i)$):

$$ \text{kv}_t = \text{kv}_{t-1} + \exp(k_t) \cdot v_t $$

This is the essence of RWKV: attention reformulated as a linear recurrence with $O(1)$ state.
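A small numerical check makes the equivalence concrete. This is a sketch only; per-channel (element-wise) weighting is assumed, matching the vector recurrence above:

```python
import torch

torch.manual_seed(0)
L, d = 16, 8
k, v = torch.randn(L, d), torch.randn(L, d)

# Explicit form: recompute the sum over all previous tokens at every step.
explicit = torch.stack(
    [(torch.exp(k[: t + 1]) * v[: t + 1]).sum(dim=0) for t in range(L)]
)

# Recurrent form: one fixed-size state, updated in O(1) per token.
kv = torch.zeros(d)
recurrent = []
for t in range(L):
    kv = kv + torch.exp(k[t]) * v[t]   # kv_t = kv_{t-1} + exp(k_t) * v_t
    recurrent.append(kv.clone())
recurrent = torch.stack(recurrent)

assert torch.allclose(explicit, recurrent, atol=1e-5)
```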

The WKV Operator

RWKV formalizes this recurrence with the WKV (Weighted Key-Value) operator:

$$ \text{wkv}_t = \frac{\sum_{i=1}^{t-1} \exp(-(t-1-i)w + k_i) v_i + \exp(u + k_t) v_t}{\sum_{i=1}^{t-1} \exp(-(t-1-i)w + k_i) + \exp(u + k_t)} $$

Three components:

- $w$: a learned, per-channel decay rate; a past contribution is down-weighted by a factor of $e^{-w}$ for each additional step of distance.
- $u$: a learned, per-channel "bonus" applied only to the current token, letting position $t$ weight itself separately from the decay schedule.
- $k_i$, $v_i$: keys and values projected from the input, as in standard attention.
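For reference, here is a direct, deliberately quadratic evaluation of the formula above. It is a sketch that assumes per-channel vectors `w` and `u`, with 0-based indexing in code versus 1-based in the formula:

```python
import torch

def wkv_direct(k, v, w, u):
    """Reference O(L^2) evaluation of the WKV formula, per channel.

    k, v: (L, d) keys and values; w, u: (d,) decay and current-token bonus.
    """
    L, d = k.shape
    out = torch.empty_like(v)
    for t in range(L):
        # Decayed weights for past tokens i < t: exp(-(t-1-i)*w + k_i)
        if t > 0:
            past = torch.stack([torch.exp(-(t - 1 - i) * w + k[i]) for i in range(t)])
        else:
            past = torch.zeros(0, d)
        cur = torch.exp(u + k[t])                     # current token gets the u bonus
        num = (past * v[:t]).sum(dim=0) + cur * v[t]
        den = past.sum(dim=0) + cur
        out[t] = num / den
    return out
```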

Both numerator and denominator admit a simple recurrence:

$$ \text{wkv}_t = \frac{\text{num}_{t-1} + \exp(u + k_t) \cdot v_t}{\text{den}_{t-1} + \exp(u + k_t)} $$
$$ \text{num}_t = \exp(-w) \cdot \text{num}_{t-1} + \exp(k_t) \cdot v_t $$
$$ \text{den}_t = \exp(-w) \cdot \text{den}_{t-1} + \exp(k_t) $$

with $\text{num}_0 = \text{den}_0 = 0$. Note that the bonus $u$ is applied only when reading out the current token; it is never folded into the running state, which keeps the recurrence consistent with the formula above.
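The same quantity with $O(1)$ state, following the recurrence above (again a sketch with per-channel `w` and `u`):

```python
import torch

def wkv_recurrent(k, v, w, u):
    """WKV with O(1) state: two running sums (num, den) per channel."""
    L, d = k.shape
    num = torch.zeros(d)
    den = torch.zeros(d)
    out = torch.empty_like(v)
    for t in range(L):
        cur = torch.exp(u + k[t])                  # bonus applies only at read-out
        out[t] = (num + cur * v[t]) / (den + cur)
        # Decay the history, then add the current token *without* the bonus.
        num = torch.exp(-w) * num + torch.exp(k[t]) * v[t]
        den = torch.exp(-w) * den + torch.exp(k[t])
    return out
```

For random inputs, `wkv_recurrent` agrees with the quadratic `wkv_direct` reference up to floating-point precision. Practical implementations typically also track a running maximum of the exponents so the $\exp$ terms stay in a safe numerical range; the sketch omits that for clarity.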

Time Mixing vs. Channel Mixing

Each RWKV block has two sub-modules:

Time Mixing (TM)

RWKV's replacement for self-attention. It runs the WKV recurrence to mix information across time steps.

The receptance gate controls how much of the WKV output passes through, functioning like a learned information filter.
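As a rough picture only: the sketch below assumes plain linear projections `W_r`, `W_k`, `W_v`, `W_o` as parameters, omits token shift and other details of the full architecture, and reuses `wkv_recurrent` from the previous sketch.

```python
import torch

def time_mixing(x, W_r, W_k, W_v, W_o, w, u):
    """Sketch of time mixing: project, run WKV, gate with receptance, project out."""
    r = torch.sigmoid(x @ W_r)        # receptance in (0, 1): a learned information filter
    k = x @ W_k
    v = x @ W_v
    wkv = wkv_recurrent(k, v, w, u)   # recurrence from the previous sketch
    return (r * wkv) @ W_o            # gate the WKV output, then project
```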

Channel Mixing (CM)

RWKV's replacement for the FFN, using a GLU variant with squared ReLU:

$$ k_t = x_t W_k \qquad r_t = \sigma(x_t W_r) $$

$$ v_t = \text{ReLU}(k_t)^2 W_v \qquad o_t = r_t \odot v_t $$

Squared ReLU ($\text{ReLU}(x)^2$) gives smoother gradients than standard ReLU while preserving sparsity.
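In code, following the equations above (a sketch; the projection matrices are assumed parameters):

```python
import torch

def channel_mixing(x, W_k, W_r, W_v):
    """Channel mixing per the equations above: a gated FFN with squared ReLU."""
    k = x @ W_k                        # k_t = x_t W_k
    r = torch.sigmoid(x @ W_r)         # r_t = sigma(x_t W_r)
    v = torch.relu(k) ** 2 @ W_v       # v_t = ReLU(k_t)^2 W_v
    return r * v                       # o_t = r_t ⊙ v_t
```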

Why RWKV Matters