In Part 1, we explored the mathematical foundations of RWKV: how attention can be rewritten as a linear recurrence via the WKV operator, enabling parallel training and $O(1)$ inference.
Today, in Part 2, we turn that math into code. I've written a complete RWKV implementation in pure PyTorch—no dependencies beyond PyTorch itself, no custom CUDA kernels, just clean tensor operations. The code supports both a parallel training mode (for efficiency) and a recurrent inference mode (for memory-flat generation).
The Implementation Strategy
RWKV's elegance lies in its simplicity. The architecture decomposes into two modules:
- Time Mixing (TM): The attention equivalent, computing the WKV recurrence.
- Channel Mixing (CM): The FFN equivalent, using gated squared ReLU.
Each module follows a pre-norm pattern similar to Transformers: $x \leftarrow x + \text{Module}(\text{LayerNorm}(x))$
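As a sketch of that pattern (`PreNormBlock` and its `mixer` argument are placeholder names, not part of the implementation below—in the concrete modules later in this post, the LayerNorm is folded inside each mixer, which is equivalent):

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-norm residual wrapper: x <- x + mixer(LayerNorm(x)).
    `mixer` stands in for either Time Mixing or Channel Mixing."""
    def __init__(self, embed_dim, mixer):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim)
        self.mixer = mixer

    def forward(self, x):
        return x + self.mixer(self.ln(x))
```

The residual connection keeps the identity path clean, so stacking many blocks stays trainable.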
Time Mixing: Parallel Training Mode
The key challenge is computing the WKV operator efficiently. During training, we have access to the entire sequence at once, so we want to parallelize the recurrence.
Recall from Part 1 that the (unnormalized) WKV numerator at time $t$ is:

$$\mathrm{wkv}_t = u \cdot k_t v_t + \sum_{i < t} w^{\,t-i} \, k_i v_i$$

where $w \in (0, 1)$ is the per-channel decay and $u$ is the bonus weight on the current token. This can be computed using a cumulative sum with decay. The key steps are: compute $kv_t = k_t \cdot v_t$ for all $t$, apply decay weights, accumulate with a cumulative sum, and rescale. In PyTorch:
```python
def _compute_wkv_parallel(self, kv, time_decay, time_first):
    # kv = k * v, shape (B, T, C); per-channel decay w in (0, 1), bonus u
    B, T, C = kv.shape
    time_decay = time_decay.expand(B, T, C)
    time_first = time_first.expand(B, T, C)
    time_indices = torch.arange(T, device=kv.device).view(1, T, 1)
    # Weight kv_i by w^{-i}, accumulate, then rescale by w^t so that token i
    # contributes with weight w^{t-i} at time t. (w^{-i} can overflow for
    # very long sequences; production kernels work in log space instead.)
    weighted_kv = kv * torch.pow(time_decay, -time_indices)
    cumsum = torch.cumsum(weighted_kv, dim=1)
    wkv = time_first * kv  # current-token term: u * kv_t
    if T > 1:
        wkv[:, 1:] = wkv[:, 1:] + torch.pow(time_decay[:, 1:], time_indices[:, 1:]) * cumsum[:, :-1]
    return wkv
```
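A useful sanity check is to verify that the decayed cumulative sum reproduces the step-by-step recurrence exactly. Here is a minimal standalone version of both modes (free functions rather than class methods, implementing $\mathrm{wkv}_t = u \cdot kv_t + \sum_{i<t} w^{t-i} kv_i$ directly):

```python
import torch

def wkv_parallel(kv, w, u):
    """Decayed cumulative sum over time (dim 1). kv: (B, T, C); w, u: (C,), 0 < w < 1."""
    B, T, C = kv.shape
    idx = torch.arange(T).view(1, T, 1)
    # weight kv_i by w^{-i}, cumsum, rescale by w^t -> effective weight w^{t-i}
    cumsum = torch.cumsum(kv * w ** (-idx), dim=1)
    wkv = u * kv
    if T > 1:
        wkv[:, 1:] += w ** idx[:, 1:] * cumsum[:, :-1]
    return wkv

def wkv_recurrent(kv, w, u):
    """Step-by-step recurrence: state_t = kv_t + w * state_{t-1}."""
    B, T, C = kv.shape
    state = torch.zeros(B, C)
    out = []
    for t in range(T):
        out.append(u * kv[:, t] + w * state)
        state = kv[:, t] + w * state
    return torch.stack(out, dim=1)

kv = torch.randn(2, 16, 8)
w = torch.rand(8) * 0.5 + 0.4  # decays in (0.4, 0.9)
u = torch.rand(8)
assert torch.allclose(wkv_parallel(kv, w, u), wkv_recurrent(kv, w, u), atol=1e-4)
```

If the two disagree, the decay/rescale bookkeeping is off somewhere—this check caught more bugs in my implementation than anything else.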
Time Mixing: Recurrent Inference Mode
During inference, we generate tokens one at a time. Here, the recurrence shines—we simply maintain the accumulated state:
```python
def forward_recurrent(self, x, state=None):
    # x is (B, C): the embedding of a single token
    B, C = x.shape
    x = self.ln(x)
    time_decay, time_first = self._get_wkv_weights(B, C, x.device)
    wkv_state = torch.zeros(B, C, device=x.device) if state is None else state[0]
    k_t = self.key(x)
    v_t = self.value(x)
    r_t = self.receptance(x)
    kv_t = k_t * v_t
    # WKV recurrence, single step: bonus u on the current token,
    # decayed state carrying all past contributions
    wkv_t = time_first[:, 0] * kv_t + time_decay[:, 0] * wkv_state
    # The stored state excludes the u bonus, so it matches the parallel form
    new_state = kv_t + time_decay[:, 0] * wkv_state
    # Apply receptance gate
    output = self.output(torch.sigmoid(r_t) * wkv_t)
    return output, (new_state,)
```
The state is just one tensor of shape $(B, C)$ per layer—constant size regardless of how many tokens we've generated. A Transformer, by contrast, must store keys and values of shape $(B, L, C)$ per layer, a KV cache that grows with every token!
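To put rough numbers on this (a hypothetical config: 12 layers, $C = 768$, float32, batch size 1—not measured from a real run):

```python
# Back-of-the-envelope memory comparison (hypothetical config)
layers, C, bytes_per = 12, 768, 4  # 12 layers, 768 channels, float32

# RWKV: one (B, C) state tensor per layer, independent of sequence length
rwkv_state = layers * C * bytes_per

# Transformer: K and V of shape (B, L, C) per layer, growing with L
def kv_cache(L):
    return layers * 2 * L * C * bytes_per

print(rwkv_state)      # 36864 bytes = 36 KiB, constant
print(kv_cache(4096))  # 301989888 bytes = 288 MiB at 4096 tokens
```

Four orders of magnitude at a modest 4K context—and the gap keeps widening as the sequence grows.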
Channel Mixing: Squared ReLU Gating
The Channel Mixing module uses a GLU-style gating mechanism with squared ReLU:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RWKVChannelMixer(nn.Module):
    def __init__(self, embed_dim, expand_factor=4):
        super().__init__()
        self.hidden_dim = embed_dim * expand_factor
        self.key = nn.Linear(embed_dim, self.hidden_dim, bias=False)
        self.receptance = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(self.hidden_dim, embed_dim, bias=False)
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.ln(x)
        # sigmoid(r) gates a squared-ReLU value path, GLU-style
        return torch.sigmoid(self.receptance(x)) * self.value(F.relu(self.key(x)) ** 2)
```
The squared ReLU ($\text{ReLU}(x)^2$) provides smoother gradients than standard ReLU while maintaining sparsity—a subtle but important design choice for training stability.
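The gradient behavior is easy to inspect with autograd: $\frac{d}{dx}\,\mathrm{ReLU}(x)^2 = 2x$ for $x > 0$ and $0$ otherwise, so the gradient rises smoothly from zero at the origin instead of jumping from 0 to 1 the way plain ReLU's does. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.1, 1.0, 2.0], requires_grad=True)
y = (F.relu(x) ** 2).sum()
y.backward()
# d/dx ReLU(x)^2 = 2x for x > 0, 0 otherwise -> no discontinuity at 0
print(x.grad)  # tensor([0.0000, 0.2000, 2.0000, 4.0000])
```

Negative inputs still get exactly zero gradient, which is where the sparsity comes from.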
The Complete RWKV Model
Stacking Time Mixing and Channel Mixing blocks gives us the full RWKV model:
```python
class RWKV(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads=1, expand_factor=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList([
            RWKVBlock(embed_dim, num_heads, expand_factor) for _ in range(num_layers)
        ])
        self.ln_out = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Weight tying (like Transformers)
        self.head.weight = self.embedding.weight

    def forward(self, x, states=None, use_recurrent=False):
        x = self.embedding(x)
        if use_recurrent and states is None:
            states = [None] * len(self.blocks)  # each block starts from empty context
        new_states = []
        for i, block in enumerate(self.blocks):
            if use_recurrent:
                x, state = block(x, states[i], use_recurrent=True)
                new_states.append(state)
            else:
                x, _ = block(x)
        return self.head(self.ln_out(x)), (new_states if use_recurrent else None)
```
Weight tying between the embedding and output projection is a standard technique that improves parameter efficiency and training stability—the same weights learn to both encode inputs and predict outputs.
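The tying is a plain attribute assignment, and it works because `nn.Linear(embed_dim, vocab_size)` stores its weight as `(vocab_size, embed_dim)`—the same shape as the embedding table. A standalone check with toy sizes:

```python
import torch
import torch.nn as nn

vocab_size, embed_dim = 100, 16
embedding = nn.Embedding(vocab_size, embed_dim)
head = nn.Linear(embed_dim, vocab_size, bias=False)
head.weight = embedding.weight  # both are (vocab_size, embed_dim)

# Same underlying storage: an optimizer step on one moves the other
assert head.weight.data_ptr() == embedding.weight.data_ptr()
```

After tying, the two modules contribute a single parameter tensor, saving `vocab_size * embed_dim` parameters.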
Autoregressive Generation
With the recurrent mode implemented, autoregressive generation is straightforward:
```python
@torch.no_grad()
def generate(self, prompt_tokens, max_new_tokens, temperature=1.0):
    self.eval()
    context = prompt_tokens.clone()
    states = None
    # Feed the prompt through the recurrent path token by token so the
    # per-layer states actually contain the prompt context. (The parallel
    # path is faster, but as written it does not return states.)
    for t in range(context.shape[1]):
        logits, states = self(context[:, t], states, use_recurrent=True)
    for _ in range(max_new_tokens):
        next_token = torch.multinomial(
            F.softmax(logits / temperature, dim=-1), num_samples=1
        )
        context = torch.cat([context, next_token], dim=1)
        # Single recurrent step: O(1) memory
        logits, states = self(next_token.squeeze(1), states, use_recurrent=True)
    return context
```
Once the prompt has been consumed, each new token requires only a single recurrent step with $O(1)$ memory—there is no KV cache to grow.
Implementation Gotchas
A few pitfalls to watch for when implementing RWKV:
- Decay Parameterization: The decay $w$ must be positive to ensure stable fading memory. Parameterize via sigmoid to keep it in $(0, 1)$.
- Time First Initialization: The $u$ parameter should be initialized to emphasize the current token ($u = 1.0$, giving $\sigma(u) \approx 0.73$).
- Gradient Clipping: RWKV can have gradient explosions in early training. Gradient clipping at norm 1.0 is essential.
- State Initialization: During recurrent inference, the state starts at zero, representing no prior context.
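The first two gotchas can be folded into the weight getter. A minimal sketch (the behavior of `_get_wkv_weights` isn't shown in this post, so the `WKVWeights` module below is an assumption about how the raw parameters would be squashed):

```python
import torch
import torch.nn as nn

class WKVWeights(nn.Module):
    """Sketch: store unconstrained raw parameters and squash them with
    sigmoid, so the decay w stays in (0, 1) and memory fades stably."""
    def __init__(self, embed_dim):
        super().__init__()
        self.time_decay_raw = nn.Parameter(torch.zeros(embed_dim))
        # u initialized to 1.0 -> sigmoid(1.0) ~= 0.73, emphasizing the current token
        self.time_first_raw = nn.Parameter(torch.ones(embed_dim))

    def forward(self):
        return torch.sigmoid(self.time_decay_raw), torch.sigmoid(self.time_first_raw)
```

For the third gotcha, `torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)` after `backward()` implements the clipping described above.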
Next Steps: Benchmarking
The code is complete. In Part 3, we'll put RWKV to the test against Transformer and LSTM baselines on synthetic sequence tasks, measuring training convergence, inference latency, and memory usage—validating the $O(1)$ inference promise in practice.