Part 1 covered the math: how attention becomes a linear recurrence through the WKV operator.
Here we turn that into working code. The implementation is pure PyTorch--no dependencies beyond torch itself, no custom CUDA kernels. It supports both parallel training mode and recurrent inference mode with $O(1)$ memory.
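All code below assumes the standard PyTorch imports:

import torch
import torch.nn as nn
import torch.nn.functional as F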
Architecture Overview
Two modules per block, each wrapped in a pre-norm residual connection ($x \leftarrow x + \text{Module}(\text{LayerNorm}(x))$); a minimal sketch of the block wiring follows the list.
- Time Mixing (TM): Attention equivalent -- computes the WKV recurrence.
- Channel Mixing (CM): FFN equivalent -- gated squared ReLU.
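Here is a minimal sketch of the block wiring under these conventions. RWKVTimeMixer is an assumed name for the time-mixing module whose WKV code appears in the next two sections; RWKVChannelMixer is defined later in this post. Each sub-module applies its own LayerNorm internally, so the block only adds the residuals:

class RWKVBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=1, expand_factor=4):
        super().__init__()
        self.time_mixer = RWKVTimeMixer(embed_dim, num_heads)        # assumed class name
        self.channel_mixer = RWKVChannelMixer(embed_dim, expand_factor)

    def forward(self, x, state=None, use_recurrent=False):
        if use_recurrent:
            tm_out, new_state = self.time_mixer.forward_recurrent(x, state)
        else:
            tm_out, new_state = self.time_mixer(x), None
        x = x + tm_out                     # residual around time mixing
        x = x + self.channel_mixer(x)      # residual around channel mixing
        return x, new_state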
Time Mixing: Parallel Training Mode
During training we have the full sequence, so we parallelize the WKV recurrence. The numerator at time $t$ combines a bonus-weighted current term with a decay-weighted sum over the past: $\text{wkv}_t = u \odot k_t v_t + \sum_{i=0}^{t-1} w^{t-i} \odot k_i v_i$, where $u$ (time_first) and $w$ (time_decay) are learned per-channel vectors and all products are element-wise.
This decomposes into element-wise products, decay-weighted terms, and a cumulative sum:
def _compute_wkv_parallel(self, kv, time_decay, time_first, T):
    # kv holds the element-wise products k_t * v_t, shape (B, T, C)
    B, _, C = kv.shape
    time_decay = time_decay.expand(B, T, C)
    time_first = time_first.expand(B, T, C)
    time_indices = torch.arange(T, device=kv.device).view(1, T, 1)
    # Weight each k_i v_i by w^{-i}; the cumulative sum then gives
    # sum_{i<=t} w^{-i} k_i v_i, and multiplying by w^t restores w^{t-i}.
    # (w^{-i} grows with i, so very long sequences need a log-space or
    # chunked formulation; we keep the naive form for clarity.)
    weighted_kv = kv * torch.pow(time_decay, -time_indices)
    cumsum = torch.cumsum(weighted_kv, dim=1)
    decay_t = torch.pow(time_decay, time_indices)
    wkv = torch.zeros_like(kv)
    wkv[:, 0] = time_first[:, 0] * kv[:, 0]
    if T > 1:
        # Bonus u on the current token plus the decayed sum over the past
        wkv[:, 1:] = time_first[:, 1:] * kv[:, 1:] + decay_t[:, 1:] * cumsum[:, :-1]
    return wkv
Time Mixing: Recurrent Inference Mode
At inference time we generate one token at a time, so we just step the recurrence forward:
def forward_recurrent(self, x, state=None):
    # x is (B, C) for a single token
    B, C = x.shape
    x = self.ln(x)
    time_decay, time_first = self._get_wkv_weights(B, C, x.device)
    wkv_state = torch.zeros(B, C, device=x.device) if state is None else state[0]
    k_t = self.key(x)
    v_t = self.value(x)
    r_t = self.receptance(x)
    kv_t = k_t * v_t
    # WKV recurrence, single step: bonus u on the current token plus the
    # decayed sum of past tokens carried in the state
    wkv_t = time_first[:, 0] * kv_t + time_decay[:, 0] * wkv_state
    # The carried state accumulates without the u bonus, matching the
    # parallel formulation above
    new_state = time_decay[:, 0] * wkv_state + kv_t
    # Apply receptance gate
    output = self.output(torch.sigmoid(r_t) * wkv_t)
    return output, (new_state,)
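As a quick standalone sanity check (illustrative shapes and values, outside the model classes), the cumsum form and the step-by-step recurrence produce the same outputs:

# The parallel cumsum form and the single-step recurrence should agree
B, T, C = 2, 16, 8
kv = torch.randn(B, T, C)                  # element-wise k_t * v_t products
w = torch.rand(1, 1, C) * 0.15 + 0.8       # decay in (0.8, 0.95)
u = torch.rand(1, 1, C)                    # time-first bonus

# Recurrent reference: wkv_t = u*kv_t + w*state, then state <- w*state + kv_t
state = torch.zeros(B, C)
ref = []
for t in range(T):
    ref.append(u[:, 0] * kv[:, t] + w[:, 0] * state)
    state = w[:, 0] * state + kv[:, t]
ref = torch.stack(ref, dim=1)

# Parallel form, mirroring _compute_wkv_parallel above
idx = torch.arange(T).view(1, T, 1)
cumsum = torch.cumsum(kv * w.pow(-idx), dim=1)
par = u * kv
par[:, 1:] += w.pow(idx)[:, 1:] * cumsum[:, :-1]

assert torch.allclose(ref, par, atol=1e-4)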
The entire state is a single $(B, C)$ tensor -- fixed size no matter how many tokens have been generated. A Transformer needs $(B, L, C)$ keys and values per layer in its KV cache, growing with every step.
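For a concrete sense of scale (illustrative numbers, per layer, fp32):

# Back-of-envelope memory for one layer at C = 1024 and a 4096-token context
B, C, L = 1, 1024, 4096
rwkv_state_bytes = B * C * 4            # 4 KB, independent of sequence length
kv_cache_bytes = 2 * B * L * C * 4      # 32 MB (keys + values), grows with L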
Channel Mixing: Squared ReLU Gating
class RWKVChannelMixer(nn.Module):
def __init__(self, embed_dim, expand_factor=4):
super().__init__()
self.hidden_dim = embed_dim * expand_factor
self.key = nn.Linear(embed_dim, self.hidden_dim, bias=False)
self.receptance = nn.Linear(embed_dim, embed_dim, bias=False)
self.value = nn.Linear(self.hidden_dim, embed_dim, bias=False)
self.ln = nn.LayerNorm(embed_dim)
def forward(self, x):
x = self.ln(x)
return torch.sigmoid(self.receptance(x)) * self.value(F.relu(self.key(x)) ** 2)
Squared ReLU ($\text{ReLU}(x)^2$) preserves ReLU's exact zeros (and hence its sparsity), but its gradient is $2\,\text{ReLU}(x)$, which rises continuously from zero instead of jumping from 0 to 1 at the origin, giving smoother optimization.
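A quick autograd check makes this concrete (illustrative values):

# Gradient of squared ReLU at a few points
x = torch.linspace(-1.0, 1.0, 5, requires_grad=True)
torch.relu(x).pow(2).sum().backward()
print(x.grad)   # tensor([0., 0., 0., 1., 2.])  == 2 * relu(x)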
The Complete RWKV Model
Stack the blocks, add weight tying between embedding and output head:
class RWKV(nn.Module):
def __init__(self, vocab_size, embed_dim, num_layers, num_heads=1, expand_factor=4):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.blocks = nn.ModuleList([
RWKVBlock(embed_dim, num_heads, expand_factor) for _ in range(num_layers)
])
self.ln_out = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, vocab_size, bias=False)
# Weight tying (like Transformers)
self.head.weight = self.embedding.weight
    def forward(self, x, states=None, use_recurrent=False):
        x = self.embedding(x)
        if states is None:
            states = [None] * len(self.blocks)
        new_states = []
        for i, block in enumerate(self.blocks):
            if use_recurrent:
                x, state = block(x, states[i], use_recurrent=True)
            else:
                x, state = block(x)
            # Collect each block's updated state so it can be fed back in
            new_states.append(state)
        return self.head(self.ln_out(x)), new_states
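A quick shape check with illustrative hyperparameters, assuming the RWKVBlock sketch above and a time-mixing forward that wraps _compute_wkv_parallel:

# Smoke test: parallel mode over a short random sequence
model = RWKV(vocab_size=1000, embed_dim=256, num_layers=4)
tokens = torch.randint(0, 1000, (2, 32))   # (B, T) token ids
logits, _ = model(tokens)                  # parallel training mode
print(logits.shape)                        # torch.Size([2, 32, 1000])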
Autoregressive Generation
Build the recurrent state over the prompt, then generate with single-step recurrence. (Processing the prompt in parallel and handing its final WKV state to the recurrent path would be faster; this version steps through the prompt token by token to keep the code self-contained.)
@torch.no_grad()
def generate(self, prompt_tokens, max_new_tokens, temperature=1.0):
    self.eval()
    context = prompt_tokens.clone()
    # Warm up the recurrent state on the prompt, one token at a time,
    # so generation actually conditions on the prompt
    states = None
    for t in range(context.size(1)):
        logits, states = self(context[:, t], states, use_recurrent=True)
    for _ in range(max_new_tokens):
        next_token = torch.multinomial(
            F.softmax(logits / temperature, dim=-1), num_samples=1
        )
        context = torch.cat([context, next_token], dim=1)
        # Single recurrent step: O(1) memory
        logits, states = self(next_token.squeeze(1), states, use_recurrent=True)
    return context
Each new token after the prompt costs $O(1)$ memory -- no growing cache.
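Usage with illustrative token ids (an untrained model will of course produce random continuations):

prompt = torch.randint(0, 1000, (1, 8))    # (B=1, T=8) token ids
out = model.generate(prompt, max_new_tokens=16, temperature=0.8)
print(out.shape)                           # torch.Size([1, 24])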
Implementation Pitfalls
- Decay must stay in $(0, 1)$. Parameterize $w$ through a sigmoid so it can never leave that range (see the sketch after this list).
- Time-first initialization matters. Set $u = 1.0$ (so $\sigma(u) \approx 0.73$) to bias toward the current token.
- Gradient clipping is not optional. Without clipping at norm 1.0, early training often diverges.
- Zero-initialize state. The recurrent state starts at zero, representing no prior context.
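A minimal sketch of how these parameterizations might look (WKVParams is a hypothetical helper, not part of the post's model code; the clipping call is the standard PyTorch utility):

class WKVParams(nn.Module):
    # Hypothetical helper: decay kept in (0, 1) via a sigmoid, time-first
    # logit initialized at 1.0 so sigmoid(u) starts near 0.73
    def __init__(self, embed_dim):
        super().__init__()
        self.decay_logit = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.first_logit = nn.Parameter(torch.ones(1, 1, embed_dim))

    def forward(self):
        time_decay = torch.sigmoid(self.decay_logit)   # always in (0, 1)
        time_first = torch.sigmoid(self.first_logit)   # starts near 0.73
        return time_decay, time_first

# In the training loop, clip gradients before each optimizer step:
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)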