
Deconstructing RWKV

Part 1: Linear Attention and the WKV Operator

Introduction

The Transformer architecture has dominated deep learning for nearly a decade. Its core mechanism—self-attention—enables rich contextual representations by comparing every token to every other token. But this comes at a steep cost: attention scales quadratically with sequence length, and autoregressive generation requires storing all past key-value pairs, leading to $O(L)$ memory that grows linearly with context.

Recurrent Neural Networks (RNNs) sit at the opposite extreme. They process sequences token-by-token, maintaining a fixed-size hidden state, giving them $O(1)$ memory during inference. But RNNs suffer from vanishing gradients, cannot be parallelized during training, and typically underperform Transformers on complex tasks.

What if there existed an architecture that combined the best of both worlds? Enter RWKV (Receptance Weighted Key Value)—an architecture that trains in parallel like a Transformer but infers sequentially like an RNN. In this three-part series, we deconstruct RWKV completely: from the underlying theory, to a pure PyTorch implementation, to benchmarks.

The Key Insight: Attention as Linear Recurrence

The breakthrough behind RWKV comes from a deceptively simple observation: attention can be rewritten as a linear recurrence.

Recall the standard self-attention formula:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V $$

For a single query position $t$, the output is:

$$ y_t = \sum_{i=1}^{t} \frac{\exp(q_t \cdot k_i / \sqrt{d})}{\sum_{j=1}^{t} \exp(q_t \cdot k_j / \sqrt{d})} v_i $$

The denominator is a normalization term that depends on all previous keys. This is what makes attention quadratic—we must recompute it for every query.
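To make the per-position formula concrete, here is a minimal pure-Python sketch with scalar queries, keys, and values (a toy model, not the full multi-head mechanism). Note how the softmax normalization $z$ must be recomputed from scratch for every position $t$:

```python
import math

def causal_attention(q, k, v, d=1.0):
    """Compute y_t per the per-position softmax formula, one output per t."""
    ys = []
    for t in range(len(q)):
        # Unnormalized scores against every previous key (the quadratic part).
        scores = [math.exp(q[t] * k[i] / math.sqrt(d)) for i in range(t + 1)]
        z = sum(scores)  # normalization term over all keys up to t
        ys.append(sum(s / z * v[i] for i, s in enumerate(scores)))
    return ys

q = [0.5, -0.2, 1.0]
k = [0.1, 0.3, -0.4]
v = [1.0, 2.0, 3.0]
ys = causal_attention(q, k, v)  # one output per position
```

At $t=0$ the only attention weight is 1, so the first output equals $v_0$ exactly; every later output is a convex combination of the values seen so far.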

Now consider a linear attention variant that replaces the coupled exponential $\exp(q_t \cdot k_i)$ with the separable product $\exp(q_t) \cdot \exp(k_i)$, applied elementwise. Because the query factor no longer depends on $i$, it can be pulled out of the sum, and the numerator becomes:

$$ \sum_{i=1}^{t} \exp(q_t) \cdot \exp(k_i) \cdot v_i = \exp(q_t) \cdot \sum_{i=1}^{t} \exp(k_i) \cdot v_i $$

The term $\sum_{i=1}^{t} \exp(k_i) \cdot v_i$ is a cumulative sum that can be computed recursively:

$$ \text{kv}_t = \text{kv}_{t-1} + \exp(k_t) \cdot v_t $$

This is the essence of RWKV: attention reformulated as a linear recurrence that can be computed in $O(1)$ memory during inference.
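A quick sanity check (a pure-Python sketch with scalar keys and values) confirms that the running sum reproduces the direct summation at every position:

```python
import math

def kv_direct(k, v, t):
    """Direct sum: sum_{i<=t} exp(k_i) * v_i, recomputed from scratch."""
    return sum(math.exp(k[i]) * v[i] for i in range(t + 1))

def kv_recurrent(k, v):
    """Same quantity via the recurrence kv_t = kv_{t-1} + exp(k_t) * v_t."""
    kv, out = 0.0, []
    for kt, vt in zip(k, v):
        kv += math.exp(kt) * vt   # O(1) state update per token
        out.append(kv)
    return out

k = [0.2, -1.0, 0.5, 0.0]
v = [1.0, 2.0, -0.5, 3.0]
rec = kv_recurrent(k, v)
for t in range(len(k)):
    assert abs(rec[t] - kv_direct(k, v, t)) < 1e-9
```

The recurrent version carries a single accumulator forward, which is exactly the fixed-size state that makes $O(1)$-memory inference possible.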

The WKV Operator

RWKV introduces the WKV (Weighted Key-Value) operator, which formalizes this recurrence with additional refinements:

$$ \text{wkv}_t = \frac{\sum_{i=1}^{t-1} \exp(-(t-1-i)w + k_i) v_i + \exp(u + k_t) v_t}{\sum_{i=1}^{t-1} \exp(-(t-1-i)w + k_i) + \exp(u + k_t)} $$

Here:

- $w \ge 0$ is a learned, per-channel time-decay: the contribution of token $i$ shrinks exponentially with its distance from the current position.
- $u$ is a learned "bonus" applied only to the current token $t$, compensating for the fact that it has not yet been decayed.
- $k_i$ and $v_i$ are the key and value projections of token $i$.

Both the numerator and the denominator can be maintained recursively with two state variables $a_t$ and $b_t$ (initialized to zero), which accumulate the decayed history excluding the current-token bonus:

$$ \text{wkv}_t = \frac{a_{t-1} + \exp(u + k_t)\, v_t}{b_{t-1} + \exp(u + k_t)} $$
$$ a_t = \exp(-w)\, a_{t-1} + \exp(k_t)\, v_t $$
$$ b_t = \exp(-w)\, b_{t-1} + \exp(k_t) $$

Note that $u$ never enters the state: the bonus applies only while a token is the current one, after which that token's contribution decays under $w$ like any other.
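As a sanity check, here is a pure-Python sketch (scalar, single-channel) that evaluates the WKV definition directly and via two decayed running accumulators, and confirms that the two agree at every position:

```python
import math

def wkv_direct(w, u, k, v, t):
    """WKV at position t, computed directly from the definition."""
    num = sum(math.exp(-(t - 1 - i) * w + k[i]) * v[i] for i in range(t)) \
        + math.exp(u + k[t]) * v[t]
    den = sum(math.exp(-(t - 1 - i) * w + k[i]) for i in range(t)) \
        + math.exp(u + k[t])
    return num / den

def wkv_recurrent(w, u, k, v):
    """Same values via two decayed accumulators: O(1) state per channel."""
    a = b = 0.0  # running numerator / denominator, excluding the u-bonus term
    out = []
    for kt, vt in zip(k, v):
        bonus = math.exp(u + kt)
        out.append((a + bonus * vt) / (b + bonus))
        a = math.exp(-w) * a + math.exp(kt) * vt  # decay, then absorb token t
        b = math.exp(-w) * b + math.exp(kt)
    return out

w, u = 0.4, 0.6
k = [0.1, -0.3, 0.8, 0.0]
v = [1.0, -2.0, 0.5, 3.0]
out = wkv_recurrent(w, u, k, v)
for t in range(len(k)):
    assert abs(out[t] - wkv_direct(w, u, k, v, t)) < 1e-9
```

In the real model $w$, $u$, $k_t$, and $v_t$ are vectors and the recurrence runs independently per channel; a numerically stable implementation also tracks a running maximum exponent, which we defer to Part 2.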

Time Mixing vs. Channel Mixing

RWKV decomposes the Transformer block into two distinct modules:

Time Mixing (TM)

The Time Mixing module is RWKV's equivalent of self-attention. It computes the WKV recurrence to mix information across time. The key components are:

- Receptance ($r$): a sigmoid-gated projection of the input that controls how much of the WKV output passes through.
- Key ($k$) and value ($v$): projections that feed the WKV recurrence, analogous to attention's keys and values.
- Time decay ($w$) and bonus ($u$): the learned per-channel parameters of the WKV operator itself.

The receptance gate is crucial—it controls how much of the WKV output flows through, similar to how a query determines which values to attend to.
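The gating can be sketched in one recurrent step (a single-channel toy; the scalar "projection" weights stand in for the learned matrices $W_k$, $W_v$, $W_r$ of the real module, which we build properly in Part 2):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def time_mixing_step(x_t, a, b, w, u, k_proj, v_proj, r_proj):
    """One recurrent step of single-channel time mixing (toy sketch).

    k_proj, v_proj, r_proj are scalar stand-ins for the learned
    projection matrices W_k, W_v, W_r of the full module.
    """
    k = k_proj * x_t
    v = v_proj * x_t
    r = sigmoid(r_proj * x_t)  # receptance gate, always in (0, 1)
    bonus = math.exp(u + k)
    wkv = (a + bonus * v) / (b + bonus)
    a = math.exp(-w) * a + math.exp(k) * v   # decay state, absorb token t
    b = math.exp(-w) * b + math.exp(k)
    return r * wkv, a, b  # the gate scales how much wkv flows out
```

Because $r \in (0, 1)$, the gated output never exceeds the WKV value in magnitude: the receptance decides, per channel, how "receptive" the output is to the mixed history.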

Channel Mixing (CM)

The Channel Mixing module is RWKV's equivalent of the feed-forward network (FFN). It uses a gated linear unit (GLU) variant with squared ReLU activation:

$$ k_t = x_t W_k \qquad r_t = \sigma(x_t W_r) $$
$$ v_t = \text{ReLU}(k_t)^2 W_v \qquad o_t = r_t \odot v_t $$

The squared ReLU activation ($\text{ReLU}(x)^2$) provides smoother gradients than standard ReLU while maintaining sparsity—a subtle but important design choice for training stability.
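The four equations translate directly into code. A single-channel sketch with scalar weights in place of the matrices $W_k$, $W_v$, $W_r$:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def channel_mixing(x_t, W_k, W_v, W_r):
    """Single-channel channel mixing (toy scalar version of the equations)."""
    k = x_t * W_k
    r = sigmoid(x_t * W_r)
    v = max(k, 0.0) ** 2 * W_v  # squared-ReLU: zero for k <= 0, smooth for k > 0
    return r * v                # receptance-gated output
```

Any input whose key lands in the ReLU's dead zone produces exactly zero output (sparsity), while positive keys pass through a smooth quadratic, which is where the gentler gradients come from.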

Why RWKV Matters

RWKV offers several compelling advantages over both Transformers and standard RNNs:

- Constant-memory inference: a fixed-size recurrent state replaces the growing key-value cache, so memory stays $O(1)$ in sequence length.
- Linear time: there is no quadratic attention matrix; compute grows linearly with context.
- Parallel training: unlike a classic RNN, the WKV computation can be evaluated across the whole sequence at once during training, like attention.

Next Steps: From Math to Code

The WKV operator may seem abstract, but implementing it in PyTorch is surprisingly straightforward. In Part 2, we will build the Time Mixing and Channel Mixing modules from scratch, implement both parallel training and recurrent inference modes, and assemble a complete RWKV model capable of autoregressive text generation.