Deconstructing RWKV

Part 2: Pure PyTorch Implementation

Introduction

In Part 1, we explored the mathematical foundations of RWKV: how attention can be rewritten as a linear recurrence via the WKV operator, enabling parallel training and $O(1)$ inference.

Today, in Part 2, we turn that math into code. I've written a complete RWKV implementation in pure PyTorch—no external libraries, no CUDA kernels, just clean tensor operations. The code supports both parallel training mode (for efficiency) and recurrent inference mode (for memory-flat generation).

The Implementation Strategy

RWKV's elegance lies in its simplicity. The architecture decomposes into two modules:

  1. Time Mixing (TM): The attention equivalent, computing the WKV recurrence.
  2. Channel Mixing (CM): The FFN equivalent, using gated squared ReLU.

Each module follows a pre-norm pattern similar to Transformers: $x \leftarrow x + \text{Module}(\text{LayerNorm}(x))$
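The pre-norm pattern can be expressed as a tiny generic wrapper. This is an illustrative sketch (the `PreNormResidual` name is mine, not from the RWKV code; the RWKV modules below fold the LayerNorm into their own forward instead, but the wiring is the same):

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """Generic pre-norm residual wrapper: x <- x + module(LayerNorm(x))."""
    def __init__(self, embed_dim, module):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim)
        self.module = module

    def forward(self, x):
        # Normalize first, transform, then add back the untouched input
        return x + self.module(self.ln(x))
```

Wrapping, say, an `nn.Linear(embed_dim, embed_dim)` this way preserves the input shape, and the identity path is what keeps gradients healthy in deep stacks.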

Time Mixing: Parallel Training Mode

The key challenge is computing the WKV operator efficiently. During training, we have access to the entire sequence at once, so we want to parallelize the recurrence.

Recall from Part 1 that the numerator at time $t$ is:

$$ \text{num}_t = \exp(u + k_t) v_t + \sum_{i=0}^{t-1} \exp(-(t-1-i)w + k_i) v_i $$

This can be computed with a cumulative sum plus decay rescaling. The key steps are: compute $kv_t = e^{k_t} v_t$ for all $t$, rescale by powers of the decay so a plain cumsum realizes the decayed sum, then add the current-token bonus $e^{u} kv_t$. In PyTorch (here `kv` already holds $e^{k_t} v_t$, `time_decay` holds the per-channel factor $e^{-w} \in (0, 1)$, and `time_first` holds $e^{u}$):

def _compute_wkv_parallel(self, kv, time_decay, time_first):
    B, T, C = kv.shape
    time_decay = time_decay.expand(B, T, C)
    time_first = time_first.expand(B, T, C)
    time_indices = torch.arange(T, dtype=kv.dtype, device=kv.device).view(1, T, 1)
    # Rescale by w^{-i} so the cumsum accumulates sum_{i<=t} w^{-i} kv_i.
    # NB: w^{-i} grows quickly with sequence length; production kernels
    # work in log space for numerical stability.
    weighted_kv = kv * torch.pow(time_decay, -time_indices)
    cumsum = torch.cumsum(weighted_kv, dim=1)
    wkv = torch.zeros_like(kv)
    wkv[:, 0] = time_first[:, 0] * kv[:, 0]
    if T > 1:
        # w^{t-1} * cumsum[t-1] = sum_{i<t} w^{t-1-i} kv_i, matching the formula
        wkv[:, 1:] = (time_first[:, 1:] * kv[:, 1:]
                      + torch.pow(time_decay[:, 1:], time_indices[:, 1:] - 1)
                      * cumsum[:, :-1])
    return wkv
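As a sanity check on the math, here is a self-contained sketch (the helper names are mine; scalar decay and bonus, no numerical stabilization) that computes the same numerator two ways: once with the decayed cumsum, once with the step-by-step recurrence. The two should agree to floating-point tolerance:

```python
import torch

def wkv_parallel(kv, w, u):
    """Parallel numerator: u*kv_t + sum_{i<t} w^(t-1-i) kv_i via cumsum."""
    B, T, C = kv.shape
    t_idx = torch.arange(T, dtype=kv.dtype).view(1, T, 1)
    # Rescale by w^(-i) so a plain cumsum accumulates sum_i w^(-i) kv_i
    cumsum = torch.cumsum(kv * w ** (-t_idx), dim=1)
    wkv = u * kv
    if T > 1:
        # Multiply back by w^(t-1) to recover the decayed sum
        wkv[:, 1:] = wkv[:, 1:] + w ** (t_idx[:, 1:] - 1) * cumsum[:, :-1]
    return wkv

def wkv_recurrent(kv, w, u):
    """Same numerator via the O(1)-state recurrence a_{t+1} = kv_t + w*a_t."""
    B, T, C = kv.shape
    state = torch.zeros(B, C)
    outs = []
    for t in range(T):
        outs.append(u * kv[:, t] + state)  # num_t = u*kv_t + a_t
        state = kv[:, t] + w * state       # a_{t+1} = kv_t + w*a_t
    return torch.stack(outs, dim=1)
```

Running both on random input and comparing with `torch.allclose` is a cheap regression test to keep around while refactoring the real kernel.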

Time Mixing: Recurrent Inference Mode

During inference, we generate tokens one at a time. Here, the recurrence shines—we simply maintain the accumulated state:

def forward_recurrent(self, x, state=None):
    # x is (B, C) for a single token
    B, C = x.shape
    x = self.ln(x)
    time_decay, time_first = self._get_wkv_weights(B, C, x.device)
    wkv_state = torch.zeros(B, C, device=x.device) if state is None else state[0]

    k_t = self.key(x)
    v_t = self.value(x)
    r_t = self.receptance(x)

    # WKV recurrence, single step: num_t = e^u kv_t + a_t
    kv_t = torch.exp(k_t) * v_t
    wkv_t = time_first[:, 0] * kv_t + wkv_state
    # Carry the decayed accumulator forward: a_{t+1} = kv_t + e^{-w} a_t
    new_state = kv_t + time_decay[:, 0] * wkv_state

    # Apply receptance gate
    output = self.output(torch.sigmoid(r_t) * wkv_t)
    return output, (new_state,)

The state is just a single tensor of shape $(B, C)$ per layer, constant in size no matter how many tokens we've generated. Compare this to Transformers, which must cache keys and values of shape $(B, L, C)$ per layer, growing with every generated token!
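A toy illustration of this difference (shapes only, not part of the model): after 100 steps the RWKV-style state is unchanged in size, while a KV-cache-style buffer has grown by one row per token.

```python
import torch

B, C = 2, 64
rwkv_state = torch.zeros(B, C)      # fixed (B, C), forever
kv_cache = torch.zeros(B, 0, C)     # grows one row per generated token

for _ in range(100):
    kv_cache = torch.cat([kv_cache, torch.randn(B, 1, C)], dim=1)
    # The recurrent state is decayed and overwritten, never extended
    rwkv_state = 0.9 * rwkv_state + torch.randn(B, C)

print(rwkv_state.shape)  # torch.Size([2, 64])
print(kv_cache.shape)    # torch.Size([2, 100, 64])
```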

Channel Mixing: Squared ReLU Gating

The Channel Mixing module uses a GLU-style gating mechanism with squared ReLU:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RWKVChannelMixer(nn.Module):
    def __init__(self, embed_dim, expand_factor=4):
        super().__init__()
        self.hidden_dim = embed_dim * expand_factor
        self.key = nn.Linear(embed_dim, self.hidden_dim, bias=False)
        self.receptance = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(self.hidden_dim, embed_dim, bias=False)
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.ln(x)
        return torch.sigmoid(self.receptance(x)) * self.value(F.relu(self.key(x)) ** 2)

The squared ReLU ($\text{ReLU}(x)^2$) keeps ReLU's sparsity (exact zeros for non-positive inputs), while its gradient, $2\,\text{ReLU}(x)$, rises continuously from zero instead of jumping, a subtle but important design choice for training stability.
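The gradient claim is easy to verify with autograd: the derivative of $\text{ReLU}(x)^2$ is $2\,\text{ReLU}(x)$, which is $0$ for negative inputs and grows linearly above zero, with no discontinuity at the origin.

```python
import torch
import torch.nn.functional as F

# Gradient of sum(ReLU(x)^2) w.r.t. x is 2*ReLU(x): continuous at 0,
# unlike plain ReLU's gradient, which jumps from 0 to 1
x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
y = (F.relu(x) ** 2).sum()
y.backward()
print(x.grad)  # tensor([0., 1., 4.])
```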

The Complete RWKV Model

Stacking Time Mixing and Channel Mixing blocks gives us the full RWKV model:

class RWKV(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers, num_heads=1, expand_factor=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.blocks = nn.ModuleList([
            RWKVBlock(embed_dim, num_heads, expand_factor) for _ in range(num_layers)
        ])
        self.ln_out = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)
        # Weight tying (like Transformers)
        self.head.weight = self.embedding.weight

    def forward(self, x, states=None, use_recurrent=False):
        x = self.embedding(x)
        if use_recurrent and states is None:
            states = [None] * len(self.blocks)  # one empty state per layer
        for i, block in enumerate(self.blocks):
            if use_recurrent:
                x, states[i] = block(x, states[i], use_recurrent=True)
            else:
                x, _ = block(x)
        return self.head(self.ln_out(x)), states

Weight tying between the embedding and output projection is a standard technique that improves parameter efficiency and training stability—the same weights learn to both encode inputs and predict outputs.
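Assigning one module's weight to the other ties them at the `Parameter` level, so the shapes line up ($V \times C$ for both) and the 
shared matrix is counted, stored, and updated once. A minimal standalone check (toy sizes):

```python
import torch
import torch.nn as nn

embed = nn.Embedding(100, 16)            # weight shape: (100, 16)
head = nn.Linear(16, 100, bias=False)    # weight shape: (100, 16)
head.weight = embed.weight               # now one shared Parameter

# Same underlying storage: a gradient step on one updates both
assert head.weight.data_ptr() == embed.weight.data_ptr()
print(sum(p.numel() for p in {embed.weight, head.weight}))  # 1600
```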

Autoregressive Generation

With the recurrent mode implemented, autoregressive generation is straightforward:

@torch.no_grad()
def generate(self, prompt_tokens, max_new_tokens, temperature=1.0):
    self.eval()
    context = prompt_tokens.clone()
    states = None

    # Warm up the recurrent state on the prompt, one token at a time.
    # The parallel path in this implementation doesn't export its final
    # state, so a recurrent pass is what carries the prompt context
    # into generation.
    for t in range(context.size(1)):
        logits, states = self(context[:, t], states, use_recurrent=True)

    for _ in range(max_new_tokens):
        next_token = torch.multinomial(
            F.softmax(logits / temperature, dim=-1), num_samples=1
        )
        context = torch.cat([context, next_token], dim=1)
        # Single recurrent step: O(1) memory
        logits, states = self(next_token.squeeze(1), states, use_recurrent=True)

    return context

After folding the prompt into the recurrent state, each new token requires only a single recurrent step with $O(1)$ memory and no KV cache to grow. (A production implementation would prefill the prompt in parallel and hand the end-of-sequence state to the recurrent loop; that requires the parallel mode to return its final state.)
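The `temperature` parameter controls how sharply the model commits to its top predictions. A quick standalone illustration on fixed logits:

```python
import torch
import torch.nn.functional as F

# Temperature below 1.0 sharpens the distribution toward greedy
# decoding; above 1.0 it flattens toward uniform sampling
logits = torch.tensor([2.0, 1.0, 0.0])
for temp in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temp, dim=-1)
    print(f"T={temp}: {probs.tolist()}")
```

At `T=0.5` nearly all mass sits on the top token; at `T=2.0` the three options are much closer to equally likely.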

Implementation Gotchas

A few pitfalls to watch for when implementing RWKV:

  1. Decay Parameterization: The decay $w$ must be positive to ensure stable fading memory. Parameterize via sigmoid to keep it in $(0, 1)$.
  2. Time First Initialization: The $u$ parameter should be initialized to emphasize the current token ($u = 1.0$, giving $\sigma(u) \approx 0.73$).
  3. Gradient Clipping: RWKV can have gradient explosions in early training. Gradient clipping at norm 1.0 is essential.
  4. State Initialization: During recurrent inference, the state starts at zero, representing no prior context.
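Gotchas 1 and 2 can be sketched in a few lines. This is a hypothetical parameterization (the `WKVParams` class and its method names are mine, not the post's implementation), showing the sigmoid squashing and the $u = 1.0$ initialization:

```python
import torch
import torch.nn as nn

class WKVParams(nn.Module):
    """Sketch of the decay / time-first parameterizations described above."""
    def __init__(self, embed_dim):
        super().__init__()
        # 1. Decay: raw, unconstrained parameter ...
        self.time_decay_raw = nn.Parameter(torch.zeros(embed_dim))
        # 2. Time first: initialized at u = 1.0 to favor the current token
        self.time_first_raw = nn.Parameter(torch.full((embed_dim,), 1.0))

    def decay(self):
        # ... squashed by sigmoid into (0, 1): stable fading memory
        return torch.sigmoid(self.time_decay_raw)

    def first(self):
        return torch.sigmoid(self.time_first_raw)  # sigma(1.0) ~ 0.73

# 3. Gradient clipping belongs in the training step, after loss.backward():
#    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```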

Next Steps: Benchmarking

The code is complete. In Part 3, we'll put RWKV to the test against Transformer and LSTM baselines on synthetic sequence tasks, measuring training convergence, inference latency, and memory usage—validating the $O(1)$ inference promise in practice.