In Part 1, we explored the beautiful mathematics driving State Space Models (SSMs) and Mamba, breaking down the underlying continuous differential equations, Zero-Order Hold discretization, and the core duality between recurrent inference and convolutional training. In Part 2, we leave the whiteboards behind and step into the code. While the official implementations of architectures like S4 and Mamba rely on highly optimized, complex custom CUDA kernels and hardware-aware associative scans for production performance, the mathematical heart of an SSM is surprisingly straightforward. In this post, we will walk through the core components of a functional, educational 1-Dimensional Linear State Space layer written in pure PyTorch. The complete source code is available in my GitHub repository.
1. Introduction
Deep learning architectures often seem intimidating when you glance at their official repositories. Production-grade models must run efficiently on hardware, so the foundational logic is frequently hidden behind layers of Triton kernels, parallel prefix scans, and low-level memory optimizations.
However, the mathematical heart of an SSM—mapping a 1D sequence of inputs $x$ to an output sequence $y$
via a hidden state $h$—merely involves standard tensor operations. In this post, we will dissect a
minimal PyTorch nn.Module that acts as a 1D Linear Time-Invariant (LTI) sequence model.
Our unoptimized, educational implementation will demonstrate three essential concepts:
- Defining the continuous parameters $(\mathbf{A}, \mathbf{B}, \mathbf{C}, \Delta)$ and computing their discrete counterparts $(\bar{\mathbf{A}}, \bar{\mathbf{B}})$.
- Unrolling the system into a massive 1D convolutional kernel to process entire sequences in parallel during training.
- Processing inputs dynamically, step-by-step using recurrent state updates during autoregressive inference.
2. Breaking Down the Components
We make a standard assumption to simplify the mathematics and avoid computing dense matrix exponentials (which are slow and numerically unstable): the state matrix $\mathbf{A}$ is diagonal. Instead of maintaining an $N \times N$ matrix, $\mathbf{A}$ becomes an $N$-dimensional vector that strictly operates element-wise.
Let's dissect the most critical mechanisms of our layer.
2.1 Continuous Initialization
Real language models process discrete tokens, yet an SSM actually learns continuous parameters representing a physical dynamical system. This separation lets the model learn real underlying dynamics mapped into discrete token space via a step size $\Delta$.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple1DSSM(nn.Module):
    def __init__(self, d_state: int):
        super().__init__()
        # Continuous-time parameters
        # A: (d_state,) initialized with negative values for stable decay
        self.A = nn.Parameter(-torch.rand(d_state) - 1.0)
        # B: (d_state, 1) mapping scalar input to state
        self.B = nn.Parameter(torch.randn(d_state, 1))
        # C: (1, d_state) mapping state to scalar output
        self.C = nn.Parameter(torch.randn(1, d_state))
        # Delta: the step size, enforced strictly positive via exp
        self.log_delta = nn.Parameter(torch.randn(1))
By constraining $\mathbf{A}$ to be diagonal, we shrink the state matrix from $N^2$ parameters to $N$ while retaining rich expressive capacity.
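The diagonal trick deserves a one-line sanity check: multiplying a state by a diagonal matrix is exactly an element-wise product, which is all the recurrence ever needs. A minimal sketch with made-up sizes:

```python
import torch

N = 8
a = torch.randn(N)          # diagonal entries of A, stored as a vector
h = torch.randn(N)          # a hidden state

dense = torch.diag(a) @ h   # full N x N matrix-vector product
fast = a * h                # equivalent element-wise product, O(N) not O(N^2)

print(torch.allclose(dense, fast))  # True
```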
2.2 ZOH Discretization
To process discrete text tokens, we use Zero-Order Hold (ZOH) to bridge continuous physics with discrete sequence modeling. This maps our continuous $\mathbf{A}$ onto discrete steps $\bar{\mathbf{A}}$.
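For reference, the standard ZOH update that the code below implements is

$$\bar{\mathbf{A}} = e^{\Delta \mathbf{A}}, \qquad \bar{\mathbf{B}} = \mathbf{A}^{-1}\left(\bar{\mathbf{A}} - \mathbf{I}\right)\mathbf{B}$$

With a diagonal $\mathbf{A}$, both expressions reduce to element-wise operations on the vector of diagonal entries.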
def discretize(self):
    delta = torch.exp(self.log_delta)
    # For diagonal A, the matrix exponential is an element-wise exp
    A_bar = torch.exp(delta * self.A)
    # Algebraic simplification of A^{-1}(A_bar - I)B for diagonal A
    B_bar = ((A_bar - 1.0) / self.A).unsqueeze(-1) * self.B
    return A_bar, B_bar
The mathematical magic here is how elegantly dense matrix exponentials collapse into fast, standard
torch.exp() calls since $\mathbf{A}$ is modeled as a diagonal vector.
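To convince yourself that nothing is lost in this shortcut, you can compare the element-wise exp against the dense matrix exponential of the corresponding diagonal matrix. A quick sketch with arbitrary values:

```python
import torch

delta = 0.1
A = -torch.rand(4) - 1.0   # diagonal entries, negative for stable decay

# Dense route: exponentiate the full diagonal matrix
dense = torch.linalg.matrix_exp(torch.diag(delta * A))

# Fast route: element-wise exp of the diagonal vector
fast = torch.exp(delta * A)

print(torch.allclose(torch.diagonal(dense), fast))  # True
```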
2.3 Massive Sequence Convolutions (Fast Training)
During training, we have the entire sequence available at once. Because an LTI model applies the same fixed transition to every token, we can compute an $L$-dimensional unrolled response kernel $\mathbf{K}$ ahead of time, with entries $K_k = \mathbf{C}\bar{\mathbf{A}}^{k}\bar{\mathbf{B}}$. Entry $K_k$ measures how much input $x_{T-k}$ contributes to output $y_T$.
def forward(self, x):
    """Convolutional execution for rapid parallel training."""
    batch, L = x.shape
    A_bar, B_bar = self.discretize()
    # Compute global convolutional kernel: [CB, CAB, CA^2B, ...]
    K = torch.zeros(L, device=x.device)
    A_pow = torch.ones_like(self.A)
    for k in range(L):
        K[k] = (self.C @ (A_pow.unsqueeze(-1) * B_bar)).squeeze()
        A_pow = A_pow * A_bar
    K = K.view(1, 1, L)
    x = x.view(batch, 1, L)
    # Note: PyTorch's F.conv1d computes cross-correlation!
    # Flip the kernel to make it a true causal convolution.
    K_flipped = torch.flip(K, dims=(-1,))
    y = F.conv1d(x, K_flipped, padding=L - 1)
    y = y[..., :L]  # Keep only the causal part, back to original length
    return y.squeeze(1)
Once $\mathbf{K}$ is materialized, we hand the convolution to PyTorch's optimized F.conv1d. Padding the input by $L - 1$ zeros and slicing the output back to length $L$ keeps the operation causal. The code also addresses a notorious PyTorch gotcha: F.conv1d structurally performs cross-correlation, not mathematical convolution. We must explicitly flip the kernel before applying it; otherwise the weights land in time-reversed order, with the coefficient meant for the current token applied to the oldest one.
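The gotcha is easy to reproduce with an impulse input: cross-correlation returns the kernel time-reversed, while the flipped version recovers the true impulse response. Toy values, for illustration only:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 0.0, 0.0]]])   # an impulse at position 0
K = torch.tensor([[[3.0, 2.0, 1.0]]])   # kernel [K0, K1, K2]
L = x.shape[-1]

# Cross-correlation (what F.conv1d actually does): response comes out reversed
raw = F.conv1d(x, K, padding=L - 1)[..., :L]

# Flipping first gives a true causal convolution: impulse response equals K
flipped = F.conv1d(x, torch.flip(K, dims=(-1,)), padding=L - 1)[..., :L]

print(raw)      # tensor([[[1., 2., 3.]]])  -- time-reversed
print(flipped)  # tensor([[[3., 2., 1.]]])  -- matches K
```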
2.4 O(1) Autoregressive Inference
When generating tokens one-by-one under deployment, we swap to a recurrent mode. At every step, we ingest a scalar $x_k$, multiply the pre-existing constant-sized hidden state $h$ by $\bar{\mathbf{A}}$, inject the new token via $\bar{\mathbf{B}}$, and emit $y_k$.
def step(self, x_k, h_prev):
    """Recurrent execution for O(1) memory autoregressive inference."""
    A_bar, B_bar = self.discretize()
    # State update: h_k = A_bar * h_{k-1} + B_bar * x_k
    h_k = A_bar * h_prev + (B_bar.squeeze(-1) * x_k)
    # Output projection: y_k = C h_k
    y_k = (self.C * h_k).sum(dim=-1, keepdim=True)
    return y_k, h_k
Whether calculating the response for token 10 or token 10,000, the computation overhead stays perfectly flat, effectively bypassing the memory-scaling wall that cripples standard Attention models.
3. Debugging the Duality
When building this implementation, my initial test to ensure both functions mathematically matched failed dramatically—the difference between the unrolled Convolution and the Step-by-Step recurrence was massive ($1.4 \times 10^{11}$).
This highlighted two vital, non-obvious engineering realities when translating Control Theory to code:
- Exploding Exponentials: Initially, $\mathbf{A}$ was initialized via a standard torch.randn(), which allows positive continuous values. A positive state parameter, once pushed through the exponential inside ZOH, yields a discrete transition with magnitude greater than one, so the rolling state memory blows up over long sequences. The fix was a strictly negative initialization (-torch.rand(d_state) - 1.0), forcing the dynamical system into a stable decay.
- Cross-Correlation vs Convolution: The remaining divergence was rooted in PyTorch's F.conv1d not actually behaving as formal convolution. The two modes only matched identically ($\sim 1 \times 10^{-6}$ float error) once the unrolled response kernel was flipped explicitly via torch.flip.
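The first failure mode is easy to demonstrate in isolation. With a positive continuous parameter, the discrete transition has magnitude above one and its powers explode over a long sequence; a toy illustration, not taken from the original test:

```python
import torch

delta = 0.5
a_pos = torch.tensor(0.7)    # positive continuous parameter (the bug)
a_neg = torch.tensor(-0.7)   # negative, as in the fixed initialization

A_bar_pos = torch.exp(delta * a_pos)   # ~1.42: |A_bar| > 1, unstable
A_bar_neg = torch.exp(delta * a_neg)   # ~0.70: |A_bar| < 1, stable decay

# After 1000 recurrent steps the difference compounds
print((A_bar_pos ** 1000).item())  # overflows to inf in float32
print((A_bar_neg ** 1000).item())  # underflows toward 0.0
```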
4. Conclusion and Next Steps
By examining the model piece by piece, we've dissected the theoretical bedrock of State Space Models in working PyTorch code. The duality between $O(1)$-memory recurrent execution and parallel convolutional training is an elegant property inherent to classical Control Theory mapped into neural networks.
You can find the full, unbroken Simple1DSSM module implementation in my GitHub repository, ready for
experimentation.
However, the architecture we built maps a 1D scalar to another scalar. Real language models are composed of deep, multi-channel embeddings operating over $D_{model}$ dimensions.
In Part 3 of our series, we will integrate this basic 1D component into a multi-headed block, enabling it to handle modern token embeddings natively, and pit it competitively against a Transformer on a long-sequence task.