Part 1 covered the math behind SSMs and Mamba -- continuous ODEs, ZOH discretization, and the recurrence/convolution duality. Here we translate that math into code. Production implementations of S4 and Mamba use custom CUDA kernels and hardware-aware scans, but the core algorithm is simple enough to fit in a single PyTorch nn.Module. This post walks through a minimal, educational 1D LTI State Space layer in pure PyTorch. Full source is on GitHub.
1. Introduction
Official SSM repositories look intimidating -- Triton kernels, parallel prefix scans, low-level memory tricks. But the underlying operation -- mapping a 1D input sequence $x$ to an output $y$ through a hidden state $h$ -- is just standard tensor math. We will build a minimal nn.Module that implements a 1D LTI sequence model.
Three things to demonstrate:
- Defining continuous parameters $(\mathbf{A}, \mathbf{B}, \mathbf{C}, \Delta)$ and computing their discrete counterparts $(\bar{\mathbf{A}}, \bar{\mathbf{B}})$.
- Unrolling the system into a 1D convolutional kernel for parallel training.
- Step-by-step recurrent state updates for autoregressive inference.
2. Breaking Down the Components
To avoid dense matrix exponentials (slow and numerically fragile), we assume $\mathbf{A}$ is diagonal. This reduces it from an $N \times N$ matrix to an $N$-dimensional vector with element-wise operations.
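As a quick sanity check of why the diagonal assumption helps (toy tensors, not the module we build below): multiplying a state vector by a diagonal matrix is exactly an element-wise multiply, so the $N \times N$ matrix never needs to be materialized.

```python
import torch

torch.manual_seed(0)
N = 4
a = torch.randn(N)  # the diagonal entries of A, stored as a vector
h = torch.randn(N)  # a hidden state

# Dense diagonal matrix-vector product...
dense = torch.diag(a) @ h
# ...equals a plain element-wise multiply, with no N x N matrix in sight
elementwise = a * h

assert torch.allclose(dense, elementwise)
```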
2.1 Continuous Initialization
An SSM learns continuous-time parameters representing a dynamical system. The step size $\Delta$ maps these continuous dynamics into discrete token space.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple1DSSM(nn.Module):
    def __init__(self, d_state: int):
        super().__init__()
        # Continuous-time parameters
        # A: (d_state,) initialized with negative values for stable decay
        self.A = nn.Parameter(-torch.rand(d_state) - 1.0)
        # B: (d_state, 1) mapping scalar input to state
        self.B = nn.Parameter(torch.randn(d_state, 1))
        # C: (1, d_state) mapping state to scalar output
        self.C = nn.Parameter(torch.randn(1, d_state))
        # Delta: the step size, stored as a log so exp() keeps it positive
        self.log_delta = nn.Parameter(torch.randn(1))
```
The diagonal constraint on $\mathbf{A}$ cuts memory from $O(N^2)$ to $O(N)$ without losing much expressiveness.
2.2 ZOH Discretization
Zero-Order Hold converts our continuous $\mathbf{A}$ into the discrete $\bar{\mathbf{A}}$ needed for token-level processing.
```python
    def discretize(self):
        delta = torch.exp(self.log_delta)
        # For diagonal A, the matrix exponential is an element-wise exp
        A_bar = torch.exp(delta * self.A)
        # ZOH: B_bar = A^{-1} (A_bar - I) B, element-wise for diagonal A
        B_bar = ((A_bar - 1.0) / self.A).unsqueeze(-1) * self.B
        return A_bar, B_bar
```
Because $\mathbf{A}$ is diagonal, the matrix exponential $\exp(\Delta \mathbf{A})$ reduces to element-wise torch.exp(). No eigendecomposition required.
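A small standalone check (random stand-in values for $\mathbf{A}$, $\mathbf{B}$, and $\Delta$, not the module's parameters) confirms that the diagonal shortcuts agree with the dense ZOH formulas $\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})$ and $\bar{\mathbf{B}} = \mathbf{A}^{-1}(\bar{\mathbf{A}} - \mathbf{I})\mathbf{B}$:

```python
import torch

torch.manual_seed(0)
N, delta = 4, 0.1
A = -torch.rand(N) - 1.0  # diagonal entries, strictly negative
B = torch.randn(N, 1)

# Diagonal shortcut, as in discretize()
A_bar = torch.exp(delta * A)
B_bar = ((A_bar - 1.0) / A).unsqueeze(-1) * B

# Dense reference: true matrix exponential and A^{-1} (A_bar - I) B
A_dense = torch.diag(A)
A_bar_dense = torch.linalg.matrix_exp(delta * A_dense)
B_bar_dense = torch.linalg.solve(A_dense, (A_bar_dense - torch.eye(N)) @ B)

assert torch.allclose(A_bar, torch.diagonal(A_bar_dense))
assert torch.allclose(B_bar, B_bar_dense, atol=1e-5)
```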
2.3 Sequence Convolution (Fast Training)
With the full sequence available at training time, we precompute an $L$-length impulse-response kernel $\mathbf{K}$. Entry $K[k] = \mathbf{C}\bar{\mathbf{A}}^k\bar{\mathbf{B}}$ encodes how much input $x_{t-k}$ contributes to output $y_t$.
```python
    def forward(self, x):
        """Convolutional execution for rapid parallel training."""
        batch, L = x.shape
        A_bar, B_bar = self.discretize()
        # Global convolutional kernel: [CB, C A_bar B, C A_bar^2 B, ...]
        K = torch.zeros(L, device=x.device)
        A_pow = torch.ones_like(self.A)
        for k in range(L):
            K[k] = (self.C @ (A_pow.unsqueeze(-1) * B_bar)).squeeze()
            A_pow = A_pow * A_bar
        K = K.view(1, 1, L)
        x = x.view(batch, 1, L)
        # F.conv1d computes cross-correlation, so flip the kernel:
        # that way K[k] multiplies x[t-k], a true causal convolution.
        K_flipped = torch.flip(K, dims=(-1,))
        y = F.conv1d(x, K_flipped, padding=L - 1)
        y = y[..., :L]  # keep the causal part, back to length L
        return y.squeeze(1)
```
Two details worth noting: we pad by $L-1$ and slice back to length $L$ to maintain causality, and we explicitly flip the kernel before calling F.conv1d. PyTorch's F.conv1d computes cross-correlation, not convolution -- skip the flip and the kernel is applied time-reversed, so the heaviest coefficient $\mathbf{C}\bar{\mathbf{B}}$ lands on the oldest token instead of the current one.
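A minimal illustration of the cross-correlation pitfall, using a toy impulse input and a hypothetical kernel `[1, 2, 3]` rather than the SSM kernel:

```python
import torch
import torch.nn.functional as F

# An impulse at position 1, and a 3-tap kernel
x = torch.tensor([[[0.0, 1.0, 0.0, 0.0]]])  # (batch=1, channels=1, L=4)
K = torch.tensor([[[1.0, 2.0, 3.0]]])

# Cross-correlation: the kernel slides without flipping, so the
# impulse response comes out time-reversed.
xcorr = F.conv1d(x, K, padding=2)[..., :4]  # -> [0, 3, 2, 1]

# Flipping first recovers true convolution: the impulse response
# is K itself, starting at the impulse.
conv = F.conv1d(x, torch.flip(K, dims=(-1,)), padding=2)[..., :4]  # -> [0, 1, 2, 3]
```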
2.4 O(1) Autoregressive Inference
At generation time, we switch to recurrent mode: ingest one scalar $x_k$, update the fixed-size state via $\bar{\mathbf{A}}$ and $\bar{\mathbf{B}}$, emit $y_k$.
```python
    def step(self, x_k, h_prev):
        """Recurrent execution for O(1)-memory autoregressive inference."""
        A_bar, B_bar = self.discretize()
        # State update: h_k = A_bar * h_prev + B_bar * x_k
        h_k = A_bar * h_prev + B_bar.squeeze(-1) * x_k
        # Output projection: y_k = C h_k
        y_k = (self.C * h_k).sum(dim=-1, keepdim=True)
        return y_k, h_k
```
Token 10 and token 10,000 cost the same. The state size is fixed regardless of how long the sequence gets.
3. Debugging the Duality
My first sanity check -- verifying that forward() and a loop of step() calls produce identical outputs -- failed with a discrepancy of $1.4 \times 10^{11}$. Two bugs:
- Exploding exponentials: I initialized $\mathbf{A}$ with `torch.randn()`, which allows positive values. Positive entries in $\mathbf{A}$ cause $\exp(\Delta \mathbf{A})$ to blow up. Switching the initialization to `-torch.rand() - 1.0` (strictly negative) forced stable decay.
- Cross-correlation vs. convolution: The remaining mismatch disappeared (down to $\sim 10^{-6}$ float precision) once I added `torch.flip` before the `F.conv1d` call.
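The duality can be verified end to end in a self-contained sketch (random stand-in values for the discrete parameters $\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, $\mathbf{C}$, not the trained module):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, L = 8, 32
A_bar = torch.exp(-torch.rand(N) - 0.1)  # entries in (0, 1): stable decay
B_bar = torch.randn(N)
C = torch.randn(N)
x = torch.randn(L)

# Recurrent execution: one state update per token
h = torch.zeros(N)
y_rec = []
for k in range(L):
    h = A_bar * h + B_bar * x[k]
    y_rec.append((C * h).sum())
y_rec = torch.stack(y_rec)

# Convolutional execution: K[k] = C . (A_bar^k * B_bar), flipped for conv1d
K = torch.stack([(C * A_bar**k * B_bar).sum() for k in range(L)])
y_conv = F.conv1d(x.view(1, 1, L),
                  torch.flip(K.view(1, 1, L), dims=(-1,)),
                  padding=L - 1)[0, 0, :L]

assert torch.allclose(y_rec, y_conv, atol=1e-4)
```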
4. Next Steps
The full Simple1DSSM module is in the GitHub repo.
What we have here maps a single scalar sequence to another scalar sequence. Real language models operate over $D_{model}$-dimensional embeddings. In Part 3, we wrap this 1D core into a multi-channel Mamba block and benchmark it against a Transformer on a long-sequence copying task.