
Deconstructing RWKV

Part 2: Pure PyTorch Implementation

Introduction

Part 1 covered the math: how attention becomes a linear recurrence through the WKV operator.

Here we turn that into working code. The implementation is pure PyTorch, with no other dependencies and no custom CUDA kernels. It supports both a parallel mode for training and a recurrent mode for inference whose memory stays $O(1)$ in sequence length.

Architecture Overview

Two modules per block, each with pre-norm residual connections ($x \leftarrow x + \text{Module}(\text{LayerNorm}(x))$):

  1. Time Mixing (TM): Attention equivalent -- computes the WKV recurrence.
  2. Channel Mixing (CM): FFN equivalent -- gated squared ReLU.
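A minimal sketch of this wiring, assuming each mixer maps a (B, T, C) tensor to the same shape. PreNormBlock and its constructor arguments are hypothetical (this is not the RWKVBlock used later), and plain Linear layers stand in for the real mixers:

```python
import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Hypothetical block: x <- x + TM(LN(x)), then x <- x + CM(LN(x))."""

    def __init__(self, embed_dim, time_mixer, channel_mixer):
        super().__init__()
        self.ln_tm = nn.LayerNorm(embed_dim)  # norm applied before time mixing
        self.ln_cm = nn.LayerNorm(embed_dim)  # norm applied before channel mixing
        self.time_mixer = time_mixer
        self.channel_mixer = channel_mixer

    def forward(self, x):
        x = x + self.time_mixer(self.ln_tm(x))     # attention-equivalent sublayer
        x = x + self.channel_mixer(self.ln_cm(x))  # FFN-equivalent sublayer
        return x


# Stand-in mixers, zero-initialized so each sublayer adds nothing and the
# residual path passes x through unchanged.
tm, cm = nn.Linear(8, 8), nn.Linear(8, 8)
for layer in (tm, cm):
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)

block = PreNormBlock(8, tm, cm)
x = torch.randn(2, 5, 8)
y = block(x)
```

In the mixers shown below, each module applies its own LayerNorm internally, which amounts to the same pre-norm residual.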

Time Mixing: Parallel Training Mode

During training we have the full sequence, so we parallelize the WKV recurrence. The numerator at time $t$:

$$ \text{num}_t = \exp(u + k_t) v_t + \sum_{i=0}^{t-1} \exp(-(t-1-i)w + k_i) v_i $$
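The sum in $\text{num}_t$ never has to be recomputed from scratch: it can be carried forward as a running state $a_t = e^{-w} a_{t-1} + e^{k_t} v_t$ (with $a_{-1} = 0$), giving $\text{num}_t = e^{u + k_t} v_t + a_{t-1}$. A scalar check of that identity for a single channel, with made-up inputs chosen only for illustration:

```python
import math


def num_direct(k, v, u, w, t):
    """num_t evaluated straight from the sum in the formula."""
    past = sum(math.exp(-(t - 1 - i) * w + k[i]) * v[i] for i in range(t))
    return math.exp(u + k[t]) * v[t] + past


def num_recurrent(k, v, u, w):
    """All numerators via the running state a_t = e^{-w} a_{t-1} + e^{k_t} v_t."""
    a = 0.0  # a_{-1}: no prior context
    out = []
    for t in range(len(k)):
        out.append(math.exp(u + k[t]) * v[t] + a)     # uses a_{t-1}, undecayed
        a = math.exp(-w) * a + math.exp(k[t]) * v[t]  # decay, then add token t
    return out


k = [0.2, -0.5, 1.0, 0.3]
v = [1.0, 2.0, -1.0, 0.5]
u, w = 0.7, 0.4
rec = num_recurrent(k, v, u, w)
```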

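The code simplifies this numerator: $k_t v_t$ stands in for $\exp(k_t) v_t$, a per-channel bonus $f = \sigma(u)$ replaces $\exp(u)$, a sigmoid-parameterized decay factor $d \in (0, 1)$ replaces $\exp(-w)$, and the normalizing denominator is dropped. That leaves

$$ \text{wkv}_t = f\, k_t v_t + \sum_{i=0}^{t-1} d^{\,t-1-i} k_i v_i $$

A cumulative sum over this would need the factor $d^{-i}$, which overflows on long sequences, so the code builds the weights $d^{\,t-1-i}$ directly as a causal $(T, T, C)$ tensor, at the cost of $O(T^2 C)$ memory:

def _compute_wkv_parallel(self, kv, time_decay, time_first, T):
    # kv: (B, T, C) products k_i * v_i; time_decay (d) and time_first (f) are
    # per-channel values broadcastable to (B, T, C). T is re-read from kv.
    B, T, C = kv.shape
    d = time_decay.expand(B, T, C)[0, 0]  # (C,) decay factor in (0, 1)
    f = time_first.expand(B, T, C)        # (B, T, C) bonus for the current token
    idx = torch.arange(T, device=kv.device)
    # lag[t, i] = t - 1 - i: decay steps between past token i and current token t
    lag = (idx.view(T, 1) - 1 - idx.view(1, T)).clamp(min=0).to(kv.dtype)
    causal = (idx.view(T, 1) > idx.view(1, T)).to(kv.dtype)  # 1 where i < t
    # weights[t, i, c] = d_c ** (t - 1 - i) for i < t, else 0 (computed in log space)
    weights = torch.exp(lag.unsqueeze(-1) * torch.log(d)) * causal.unsqueeze(-1)
    past = torch.einsum("tic,bic->btc", weights, kv)  # sum over past tokens i
    return f * kv + past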
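A standalone check, using plain tensors rather than module attributes, that the causal decay-weighted sum equals stepping the recurrence one token at a time. The function names wkv_parallel and wkv_sequential are illustrative, and d and f are random stand-ins for the learned decay and bonus:

```python
import torch


def wkv_parallel(kv, d, f):
    """Simplified WKV for all positions at once. kv: (B, T, C); d, f: (C,)."""
    B, T, C = kv.shape
    idx = torch.arange(T)
    lag = (idx.view(T, 1) - 1 - idx.view(1, T)).clamp(min=0).to(kv.dtype)
    causal = (idx.view(T, 1) > idx.view(1, T)).to(kv.dtype)  # strictly past: i < t
    weights = torch.exp(lag.unsqueeze(-1) * torch.log(d)) * causal.unsqueeze(-1)
    return f * kv + torch.einsum("tic,bic->btc", weights, kv)


def wkv_sequential(kv, d, f):
    """Same quantity computed token by token, as in recurrent inference."""
    B, T, C = kv.shape
    state = torch.zeros(B, C, dtype=kv.dtype)  # a_{-1} = 0: no prior context
    outs = []
    for t in range(T):
        outs.append(f * kv[:, t] + state)  # output uses a_{t-1}
        state = d * state + kv[:, t]       # a_t = d * a_{t-1} + kv_t
    return torch.stack(outs, dim=1)


torch.manual_seed(0)
kv = torch.randn(2, 6, 4, dtype=torch.float64)
d = torch.sigmoid(torch.randn(4, dtype=torch.float64))
f = torch.sigmoid(torch.ones(4, dtype=torch.float64))
par = wkv_parallel(kv, d, f)
seq = wkv_sequential(kv, d, f)
```

Running both in float64 isolates the formula from rounding; in float32 they agree only up to a small tolerance.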

Time Mixing: Recurrent Inference Mode

At inference time we generate one token at a time, so we just step the recurrence forward:

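def forward_recurrent(self, x, state=None):
    # x is (B, C) for a single token
    B, C = x.shape
    x = self.ln(x)
    time_decay, time_first = self._get_wkv_weights(B, C, x.device)
    # state[0] holds a_{t-1}, the decayed sum of past k_i * v_i (zeros = no context)
    wkv_state = torch.zeros(B, C, device=x.device) if state is None else state[0]

    k_t = self.key(x)
    v_t = self.value(x)
    r_t = self.receptance(x)
    kv_t = k_t * v_t

    # WKV output: bonus-weighted current token plus the stored past
    wkv_t = time_first[:, 0] * kv_t + wkv_state
    # State update: decay the past, then add the current token (no bonus)
    new_state = time_decay[:, 0] * wkv_state + kv_t

    # Apply receptance gate
    output = self.output(torch.sigmoid(r_t) * wkv_t)
    return output, (new_state,)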

The entire per-layer state is a single $(B, C)$ tensor, fixed in size no matter how many tokens have been generated. A Transformer's KV cache instead stores keys and values of shape $(B, L, C)$ per layer, and $L$ grows with every generated token.
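To put numbers on the difference, here is a rough element count, assuming the single-tensor state of this simplified model (full RWKV implementations keep a few extra per-layer tensors, still independent of $L$) against a cache that stores both keys and values. The sizes are hypothetical:

```python
def rwkv_state_elems(batch, channels, layers):
    """Simplified model here: one (B, C) state tensor per layer."""
    return layers * batch * channels


def kv_cache_elems(batch, channels, layers, seq_len):
    """Transformer KV cache: keys and values, each (B, L, C), per layer."""
    return layers * 2 * batch * seq_len * channels


# Hypothetical sizes, chosen only for illustration.
B, C, N_LAYERS = 1, 768, 12
for L in (1, 1024, 4096):
    print(L, rwkv_state_elems(B, C, N_LAYERS), kv_cache_elems(B, C, N_LAYERS, L))
```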

Channel Mixing: Squared ReLU Gating

import torch
import torch.nn as nn
import torch.nn.functional as F

class RWKVChannelMixer(nn.Module):
    def __init__(self, embed_dim, expand_factor=4):
        super().__init__()
        self.hidden_dim = embed_dim * expand_factor
        self.key = nn.Linear(embed_dim, self.hidden_dim, bias=False)    # expand
        self.receptance = nn.Linear(embed_dim, embed_dim, bias=False)   # gate
        self.value = nn.Linear(self.hidden_dim, embed_dim, bias=False)  # project back
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.ln(x)
        # sigmoid(receptance) gates the squared-ReLU feed-forward output
        return torch.sigmoid(self.receptance(x)) * self.value(F.relu(self.key(x)) ** 2)

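Squared ReLU ($\text{ReLU}(x)^2$) keeps ReLU's sparsity, since negative inputs still map to exactly zero, but its derivative $2\,\text{ReLU}(x)$ is continuous at the origin instead of jumping from 0 to 1, which gives smoother gradients.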
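A quick autograd check of this gradient behavior on a small, arbitrary grid of inputs:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-2.0, 2.0, steps=9, requires_grad=True)

# Gradients of plain ReLU and squared ReLU with respect to the inputs
(relu_grad,) = torch.autograd.grad(F.relu(x).sum(), x)
(sq_grad,) = torch.autograd.grad((F.relu(x) ** 2).sum(), x)

print(relu_grad)  # steps from 0 to 1 at the origin
print(sq_grad)    # equals 2 * ReLU(x), rising continuously from 0
```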

The Complete RWKV Model

Stack the blocks, add weight tying between embedding and output head:

class RWKV(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads=1, expand_factor=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList([
            RWKVBlock(embed_dim, num_heads, expand_factor) for _ in range(num_layers)
        ])
        self.ln_out = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Weight tying (like Transformers)
        self.head.weight = self.embedding.weight

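    def forward(self, x, states=None, use_recurrent=False):
        # x: (B, T) token ids in parallel mode, (B,) in recurrent mode
        x = self.embedding(x)
        new_states = []
        for i, block in enumerate(self.blocks):
            if use_recurrent:
                # states is None on the first step (no prior context)
                layer_state = None if states is None else states[i]
                x, layer_state = block(x, layer_state, use_recurrent=True)
                new_states.append(layer_state)
            else:
                x, _ = block(x)
        # Parallel mode does not track state, so it returns None
        return self.head(self.ln_out(x)), (new_states if use_recurrent else None)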
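Tying works because nn.Embedding(vocab_size, embed_dim).weight and nn.Linear(embed_dim, vocab_size).weight both have shape (vocab_size, embed_dim). A toy check, with sizes picked only for illustration, that the two layers then share one tensor and the parameter count includes it once:

```python
import torch.nn as nn

vocab_size, embed_dim = 100, 16  # toy sizes for illustration
embedding = nn.Embedding(vocab_size, embed_dim)
head = nn.Linear(embed_dim, vocab_size, bias=False)

# Both weights are (vocab_size, embed_dim), so the head can reuse the embedding's
head.weight = embedding.weight

model = nn.ModuleDict({"embedding": embedding, "head": head})
# parameters() yields a shared Parameter only once
n_params = sum(p.numel() for p in model.parameters())
```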

Autoregressive Generation

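Step the recurrence over the prompt to build up each layer's state, then keep stepping it to generate. A parallel pass over the prompt would be faster, but only if it also returned each layer's final state; the parallel branch of forward discards it, and sampling from a fresh zero state would ignore the prompt entirely.

@torch.no_grad()
def generate(self, prompt_tokens, max_new_tokens, temperature=1.0):
    self.eval()
    context = prompt_tokens.clone()  # (B, T_prompt) token ids
    states = None                    # None: every layer starts from a zero state

    # Prefill: feed the prompt through the recurrence so the state encodes it
    for t in range(context.size(1)):
        logits, states = self(context[:, t], states, use_recurrent=True)

    for _ in range(max_new_tokens):
        next_token = torch.multinomial(
            F.softmax(logits / temperature, dim=-1), num_samples=1
        )  # (B, 1)
        context = torch.cat([context, next_token], dim=1)
        # Single recurrent step: fixed-size state, O(1) memory per token
        logits, states = self(next_token.squeeze(1), states, use_recurrent=True)

    return context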

Each new token after the prompt costs $O(1)$ memory -- no growing cache.
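The temperature argument rescales the logits before the softmax: values below 1 sharpen the distribution toward the top token, values above 1 flatten it. A small illustration with made-up logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0, -1.0])  # made-up next-token logits

# Probability distribution at three temperatures
probs = {t: F.softmax(logits / t, dim=-1) for t in (0.5, 1.0, 2.0)}
for t, p in probs.items():
    print(t, [round(q, 3) for q in p.tolist()])

# torch.multinomial draws a token index in proportion to the probabilities
torch.manual_seed(0)
sample = torch.multinomial(probs[1.0], num_samples=1)
```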

Implementation Pitfalls

  1. Decay must stay in $(0, 1)$. Parameterize the decay factor through a sigmoid so past contributions shrink at every step instead of growing.
  2. Time-first initialization matters. Set $u = 1.0$ (so $\sigma(u) \approx 0.73$) to bias toward the current token.
  3. Gradient clipping is not optional. Without clipping at norm 1.0, early training often diverges.
  4. Zero-initialize state. The recurrent state starts at zero, representing no prior context.
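A minimal sketch tying the four points together. The module name DecayParams, the zero initialization of $w$ (so the decay starts at 0.5), the sigmoid of $w$ as the decay factor, and the deliberately large toy loss are assumptions for illustration; the sigmoid bonus with $u = 1.0$, clipping at norm 1.0, and the zero state come from the list above:

```python
import torch
import torch.nn as nn


class DecayParams(nn.Module):
    """Hypothetical per-channel WKV parameters following the pitfalls above."""

    def __init__(self, embed_dim):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(embed_dim))         # assumed init: decay 0.5
        self.u = nn.Parameter(torch.full((embed_dim,), 1.0))  # bonus sigma(1.0) ~ 0.73

    def forward(self):
        # Sigmoid keeps both factors strictly inside (0, 1)
        return torch.sigmoid(self.w), torch.sigmoid(self.u)


params = DecayParams(8)
decay, bonus = params()

# One toy optimization step with gradient clipping at norm 1.0
opt = torch.optim.SGD(params.parameters(), lr=0.1)
loss = (decay * 100.0).sum() + (bonus * 100.0).sum()  # deliberately large gradients
loss.backward()
total_norm = torch.nn.utils.clip_grad_norm_(params.parameters(), max_norm=1.0)
clipped = torch.sqrt(sum(p.grad.pow(2).sum() for p in params.parameters()))
opt.step()

state = torch.zeros(2, 8)  # zero-initialized recurrent state: no prior context
```

clip_grad_norm_ returns the total norm measured before clipping, which makes it easy to log how often clipping actually kicks in.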