
Deconstructing State Space Models: Part 2

Writing a 1D State Space Layer in PyTorch

Abstract

In Part 1, we explored the beautiful mathematics driving State Space Models (SSMs) and Mamba: the underlying continuous differential equations, Zero-Order Hold discretization, and the core duality between recurrent inference and convolutional training. In Part 2, we leave the whiteboard behind and step into the code. The official implementations of architectures like S4 and Mamba rely on highly optimized custom CUDA kernels and hardware-aware associative scans for production performance, but the mathematical heart of an SSM is surprisingly straightforward. In this post, we walk through the core components of a functional, educational one-dimensional (1D) linear State Space layer written in pure PyTorch. The complete source code is available in my GitHub repository.

1. Introduction

Deep learning architectures often seem intimidating when you glance at their official repositories. Production-grade models must run efficiently on hardware, so the foundational logic is frequently hidden behind layers of Triton blocks, parallel prefix scans, and low-level memory optimizations.

However, the core of an SSM, which maps a 1D sequence of inputs $x$ to an output sequence $y$ through a hidden state $h$, requires only standard tensor operations. In this post, we will dissect a minimal PyTorch nn.Module that acts as a 1D Linear Time-Invariant (LTI) sequence model.
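As a recap of Part 1, these are the continuous-time system the layer learns and the discrete recurrence it actually runs. As in all the code in this post, the skip term $\mathbf{D}x$ that many SSM formulations include is omitted:

$$
h'(t) = \mathbf{A}\,h(t) + \mathbf{B}\,x(t), \qquad y(t) = \mathbf{C}\,h(t)
$$

$$
h_k = \bar{\mathbf{A}}\,h_{k-1} + \bar{\mathbf{B}}\,x_k, \qquad y_k = \mathbf{C}\,h_k
$$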

Our unoptimized, educational implementation will demonstrate three essential concepts:

  1. Defining the continuous parameters $(\mathbf{A}, \mathbf{B}, \mathbf{C}, \Delta)$ and computing their discrete counterparts $(\bar{\mathbf{A}}, \bar{\mathbf{B}})$.
  2. Unrolling the system into a length-$L$ 1D convolutional kernel to process entire sequences in parallel during training.
  3. Processing inputs one step at a time with recurrent state updates during autoregressive inference.

2. Breaking Down the Components

We make a standard simplifying assumption: the state matrix $\mathbf{A}$ is diagonal. This avoids computing dense matrix exponentials, which cost $O(N^3)$ for an $N \times N$ matrix. Instead of maintaining an $N \times N$ matrix, we store $\mathbf{A}$ as an $N$-dimensional vector that acts element-wise on the state.
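Two identities make the diagonal shortcut work, and both are easy to check numerically. First, a diagonal matrix-vector product is just an element-wise product. Second, the exponential of a diagonal matrix is the diagonal matrix of element-wise exponentials. A quick sketch (the size $N = 4$ and the seed are arbitrary choices):

```python
import torch

torch.manual_seed(0)
N = 4
a = -torch.rand(N, dtype=torch.float64) - 1.0  # diagonal entries of A
h = torch.randn(N, dtype=torch.float64)         # an arbitrary state vector

A_dense = torch.diag(a)                         # the same A as an N x N matrix

# A diagonal matrix-vector product equals an element-wise product
same_product = torch.allclose(A_dense @ h, a * h)

# exp of a diagonal matrix is diagonal, with exp applied entry-wise
same_exp = torch.allclose(torch.linalg.matrix_exp(A_dense), torch.diag(torch.exp(a)))

print(same_product, same_exp)
```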

Let's dissect the most critical mechanisms of our layer.

2.1 Continuous Initialization

Real language models process discrete tokens, yet an SSM learns the parameters of a continuous-time dynamical system. A learned step size $\Delta$ maps these continuous dynamics onto the discrete token sequence, so the model learns the underlying dynamics and, separately, how finely to sample them.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple1DSSM(nn.Module):
    def __init__(self, d_state: int):
        super().__init__()
        # Continuous-time parameters
        # A: (d_state,) diagonal of the state matrix, initialized in (-2, -1]
        # so that every mode decays (stable dynamics)
        self.A = nn.Parameter(-torch.rand(d_state) - 1.0)
        
        # B: (d_state, 1) maps the scalar input to the state
        self.B = nn.Parameter(torch.randn(d_state, 1))
        
        # C: (1, d_state) maps the state to the scalar output
        self.C = nn.Parameter(torch.randn(1, d_state))
        
        # Step size Delta, stored as log(Delta); discretize() applies exp(),
        # which keeps Delta strictly positive throughout training
        self.log_delta = nn.Parameter(torch.randn(1))

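By storing only the diagonal of $\mathbf{A}$, we cut its parameters from $N^2$ to $N$ and turn every state update into an element-wise operation, while retaining much of the expressive capacity. Diagonal state matrices are also the choice made in S4D and Mamba.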
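To get a feel for what the learned step size $\Delta$ does, consider a single decaying mode with continuous value $a = -1$. Under ZOH discretization (section 2.2), its discrete transition is $\bar{a} = e^{\Delta a}$, so after $k$ steps the state has been scaled by $\bar{a}^k$. The sketch below (the $\Delta$ values are arbitrary) computes $\bar{a}$ and the number of steps until the state decays to one half:

```python
import math

a = -1.0  # one diagonal entry of the continuous A (a decaying mode)
results = []
for delta in (0.01, 0.1, 1.0):
    A_bar = math.exp(delta * a)                   # ZOH: A_bar = exp(delta * a)
    half_life = math.log(2.0) / (delta * abs(a))  # k such that A_bar**k == 0.5
    results.append((delta, A_bar, half_life))
    print(f"delta={delta}: A_bar={A_bar:.3f}, half-life={half_life:.1f} steps")
```

A smaller $\Delta$ pushes $\bar{a}$ toward 1 and stretches the memory over more tokens. A larger $\Delta$ makes the mode forget within a few steps.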

2.2 ZOH Discretization

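Because tokens arrive at discrete steps, we use Zero-Order Hold (ZOH) discretization, which assumes the input is held constant over each interval of length $\Delta$. This bridges the continuous dynamics and discrete sequence modeling by mapping the continuous parameters $(\mathbf{A}, \mathbf{B})$ to their discrete counterparts $(\bar{\mathbf{A}}, \bar{\mathbf{B}})$.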
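For reference, here are the standard ZOH formulas. Below them is their element-wise form when $\mathbf{A}$ is diagonal, where $\mathbf{A}_n$ and $\mathbf{B}_n$ denote the $n$-th diagonal entry and the $n$-th row:

$$
\bar{\mathbf{A}} = \exp(\Delta \mathbf{A}), \qquad
\bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}\left(\exp(\Delta \mathbf{A}) - \mathbf{I}\right)\Delta \mathbf{B}
$$

$$
\bar{\mathbf{A}}_n = e^{\Delta \mathbf{A}_n}, \qquad
\bar{\mathbf{B}}_n = \frac{e^{\Delta \mathbf{A}_n} - 1}{\mathbf{A}_n}\,\mathbf{B}_n
$$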

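    def discretize(self):
        # Recover the strictly positive step size from its log parameterization
        delta = torch.exp(self.log_delta)
        
        # A_bar = exp(Delta * A): for a diagonal A the matrix exponential
        # reduces to an element-wise exp, shape (d_state,)
        A_bar = torch.exp(delta * self.A)
        
        # B_bar = A^{-1} (exp(Delta * A) - I) B, element-wise for a diagonal A;
        # shape (d_state, 1). Assumes no entry of A is zero.
        B_bar = ((A_bar - 1.0) / self.A).unsqueeze(-1) * self.B
        
        return A_bar, B_bar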

Because $\mathbf{A}$ is stored as a diagonal vector, the matrix exponential $\exp(\Delta\mathbf{A})$ collapses into a fast, element-wise torch.exp() call. The matrix inverse in the ZOH formula for $\bar{\mathbf{B}}$ likewise becomes an element-wise division.
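To confirm that the element-wise shortcut matches the general ZOH formula, one can build the dense diagonal matrix and evaluate $\bar{\mathbf{B}} = \mathbf{A}^{-1}(\exp(\Delta\mathbf{A}) - \mathbf{I})\mathbf{B}$ directly with torch.linalg.matrix_exp and torch.linalg.solve. The $\Delta$ factors in the textbook form $(\Delta\mathbf{A})^{-1}(\cdot)\,\Delta\mathbf{B}$ cancel. The inputs below are arbitrary test values:

```python
import torch

torch.manual_seed(0)
N = 4
a = -torch.rand(N, dtype=torch.float64) - 1.0  # diagonal of A, same init as __init__
B = torch.randn(N, 1, dtype=torch.float64)
delta = torch.tensor(0.5, dtype=torch.float64)

# Element-wise shortcut, as computed in discretize()
A_bar = torch.exp(delta * a)
B_bar = ((A_bar - 1.0) / a).unsqueeze(-1) * B

# General ZOH with dense matrices: B_bar = A^{-1} (exp(delta * A) - I) B
A = torch.diag(a)
A_bar_dense = torch.linalg.matrix_exp(delta * A)
I = torch.eye(N, dtype=torch.float64)
B_bar_dense = torch.linalg.solve(A, (A_bar_dense - I) @ B)

print(torch.allclose(B_bar, B_bar_dense))
```

The division by $\mathbf{A}$ assumes that no diagonal entry is zero. The negative initialization keeps the entries away from zero at the start, but nothing in Simple1DSSM constrains $\mathbf{A}$ during training.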

2.3 Massive Sequence Convolutions (Fast Training)

During training, the entire sequence is available at once. An LTI model applies the same $\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, and $\mathbf{C}$ at every step, so we can compute a length-$L$ unrolled response kernel $\mathbf{K}$ ahead of time. Its $k$-th entry, $\mathbf{K}_k = \mathbf{C}\bar{\mathbf{A}}^k\bar{\mathbf{B}}$, determines how much the input $x_{T-k}$ contributes to the output $y_{T}$.
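Unrolling the recurrence makes the kernel explicit. This assumes a zero initial state and tokens indexed from 0, which is the convention used by the code below:

$$
\begin{aligned}
h_T &= \bar{\mathbf{A}}\,h_{T-1} + \bar{\mathbf{B}}\,x_T = \sum_{k=0}^{T} \bar{\mathbf{A}}^{k}\bar{\mathbf{B}}\,x_{T-k}, \qquad h_{-1} = 0, \\
y_T &= \mathbf{C}\,h_T = \sum_{k=0}^{T} \mathbf{K}_k\,x_{T-k}, \qquad \mathbf{K}_k = \mathbf{C}\bar{\mathbf{A}}^{k}\bar{\mathbf{B}}.
\end{aligned}
$$

The entries $\mathbf{K}_0, \dots, \mathbf{K}_{L-1}$ are exactly what the loop in forward() fills in.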

    def forward(self, x):
        """ Convolutional execution for parallel training.
        x: (batch, L) input sequence; returns y: (batch, L). """
        batch, L = x.shape
        A_bar, B_bar = self.discretize()
        
        # Compute the global convolutional kernel:
        # K = [C B_bar, C A_bar B_bar, C A_bar^2 B_bar, ...]
        K = torch.zeros(L, device=x.device, dtype=x.dtype)
        A_pow = torch.ones_like(self.A)  # A_bar^0
        
        for k in range(L):
            K[k] = (self.C @ (A_pow.unsqueeze(-1) * B_bar)).squeeze()
            A_pow = A_pow * A_bar
            
        K = K.view(1, 1, L)
        x = x.unsqueeze(1)  # (batch, 1, L)
        
        # Note: PyTorch's F.conv1d computes cross-correlation!
        # We must flip the kernel to make it a true convolution.
        K_flipped = torch.flip(K, dims=(-1,))
        # padding=L-1 pads both ends; the left zeros make the output causal,
        # and the extra outputs produced by the right zeros are sliced away.
        y = F.conv1d(x, K_flipped, padding=L-1)
        
        y = y[..., :L]  # keep the first L (causal) outputs
        return y.squeeze(1)

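Once $\mathbf{K}$ is materialized, we hand the convolution to PyTorch's optimized F.conv1d. The $L - 1$ zeros of padding on the left keep the operation causal, so output $y_t$ depends only on $x_0, \dots, x_t$. Because padding=L-1 pads both ends, F.conv1d returns $2L - 1$ outputs, and the slice keeps only the first $L$. The code also addresses a notorious PyTorch gotcha: F.conv1d performs cross-correlation, not mathematical convolution. We must flip the kernel before applying it. Otherwise, the kernel is applied in reverse: the current input $x_t$ is weighted by $\mathbf{K}_{L-1}$ and the input $L-1$ steps back by $\mathbf{K}_0$, so the output no longer matches the recurrence.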
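The flip-and-pad logic can be checked in isolation. The sketch below uses two hypothetical helpers, causal_conv_reference and conv1d_version. It compares the recipe from forward() against a direct double loop over $y_t = \sum_{k=0}^{t} \mathbf{K}_k x_{t-k}$, both with and without the flip. The sizes and seed are arbitrary:

```python
import torch
import torch.nn.functional as F

def causal_conv_reference(x, K):
    """Direct definition: y[t] = sum_{k=0}^{t} K[k] * x[t - k]. x: (batch, L), K: (L,)."""
    L = x.shape[-1]
    y = torch.zeros_like(x)
    for t in range(L):
        for k in range(t + 1):
            y[:, t] += K[k] * x[:, t - k]
    return y

def conv1d_version(x, K, flip=True):
    """Mirrors forward(): optional flip, padding=L-1, keep the first L outputs."""
    batch, L = x.shape
    weight = torch.flip(K, dims=(-1,)) if flip else K
    y = F.conv1d(x.unsqueeze(1), weight.view(1, 1, L), padding=L - 1)
    return y[..., :L].squeeze(1)

torch.manual_seed(0)
x = torch.randn(2, 8, dtype=torch.float64)
K = torch.randn(8, dtype=torch.float64)

y_ref = causal_conv_reference(x, K)
y_flipped = conv1d_version(x, K, flip=True)
y_unflipped = conv1d_version(x, K, flip=False)

# Flipped kernel: should match the direct causal convolution
print(torch.allclose(y_flipped, y_ref))
# Unflipped kernel: compare against the reference run with K reversed
print(torch.allclose(y_unflipped, causal_conv_reference(x, torch.flip(K, dims=(-1,)))))
```

The second comparison shows that, in this setup, dropping the flip does not leak future inputs. Instead, it silently applies the kernel in reverse order, which is why such a bug can go unnoticed without a reference check.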

2.4 O(1) Autoregressive Inference

When generating tokens one at a time during deployment, we switch to recurrent mode. At every step, we take a scalar input $x_k$, multiply the fixed-size hidden state $h$ by $\bar{\mathbf{A}}$, inject the new input via $\bar{\mathbf{B}}$, and read out $y_k$ through $\mathbf{C}$.

    def step(self, x_k, h_prev):
        """ Recurrent execution for autoregressive inference.
        x_k: (batch, 1) current input; h_prev: (batch, d_state) previous state
        (zeros at the first step). Memory is constant in sequence length. """
        A_bar, B_bar = self.discretize()
        
        # State update: h_k = A_bar * h_prev + B_bar * x_k  -> (batch, d_state)
        h_k = A_bar * h_prev + (B_bar.squeeze(-1) * x_k)
        
        # Output projection: y_k = C h_k  -> (batch, 1)
        y_k = (self.C * h_k).sum(dim=-1, keepdim=True)
        
        return y_k, h_k

Whether we are computing the output for token 10 or token 10,000, each step costs $O(N)$ time and memory for a state of size $N$, independent of position. This sidesteps the scaling wall of standard attention, whose key-value cache and per-token cost grow linearly with the context length.
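To confirm that forward() and step() describe the same system, one can run both modes on the same input and compare the results. The sketch below reassembles the snippets above into one runnable class (condensed, with comments removed) and starts the recurrence from a zero state. The sizes, seed, and tolerance are arbitrary choices for float32:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple1DSSM(nn.Module):
    """Condensed copy of the snippets above, so this check runs on its own."""
    def __init__(self, d_state: int):
        super().__init__()
        self.A = nn.Parameter(-torch.rand(d_state) - 1.0)
        self.B = nn.Parameter(torch.randn(d_state, 1))
        self.C = nn.Parameter(torch.randn(1, d_state))
        self.log_delta = nn.Parameter(torch.randn(1))

    def discretize(self):
        delta = torch.exp(self.log_delta)
        A_bar = torch.exp(delta * self.A)
        B_bar = ((A_bar - 1.0) / self.A).unsqueeze(-1) * self.B
        return A_bar, B_bar

    def forward(self, x):
        batch, L = x.shape
        A_bar, B_bar = self.discretize()
        K = torch.zeros(L, device=x.device)
        A_pow = torch.ones_like(self.A)
        for k in range(L):
            K[k] = (self.C @ (A_pow.unsqueeze(-1) * B_bar)).squeeze()
            A_pow = A_pow * A_bar
        K_flipped = torch.flip(K.view(1, 1, L), dims=(-1,))
        y = F.conv1d(x.unsqueeze(1), K_flipped, padding=L - 1)
        return y[..., :L].squeeze(1)

    def step(self, x_k, h_prev):
        A_bar, B_bar = self.discretize()
        h_k = A_bar * h_prev + (B_bar.squeeze(-1) * x_k)
        y_k = (self.C * h_k).sum(dim=-1, keepdim=True)
        return y_k, h_k


torch.manual_seed(0)
batch, L, d_state = 2, 32, 8
model = Simple1DSSM(d_state)
x = torch.randn(batch, L)

with torch.no_grad():
    y_conv = model(x)                            # parallel (convolutional) mode

    h = torch.zeros(batch, d_state)              # zero initial state
    outputs = []
    for t in range(L):
        y_t, h = model.step(x[:, t:t + 1], h)    # x_k has shape (batch, 1)
        outputs.append(y_t)
    y_rec = torch.cat(outputs, dim=-1)           # recurrent mode, (batch, L)

max_diff = (y_conv - y_rec).abs().max().item()
print(f"max |conv - recurrent| = {max_diff:.2e}")
```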

3. Debugging the Duality

When building this implementation, my initial test to check that both modes matched numerically failed dramatically: the difference between the unrolled convolution and the step-by-step recurrence was massive ($1.4 \times 10^{11}$).

This highlighted two vital, non-obvious engineering realities when translating Control Theory to code:

4. Conclusion and Next Steps

By working through the layer piece by piece, we have dissected the theoretical bedrock of State Space Models in PyTorch. The duality between $O(1)$-memory recurrent execution and parallel convolutional training follows directly from the system being linear and time-invariant, an elegant property carried over from classical control theory into neural networks.

You can find the complete Simple1DSSM module in my GitHub repository, ready for experimentation.

However, the layer we built maps a scalar input sequence to a scalar output sequence. Real language models operate on deep, multi-channel embeddings with $D_{model}$ dimensions.

In Part 3 of our series, we will integrate this basic 1D component into a multi-headed block so that it handles modern token embeddings natively, and we will benchmark it against a Transformer on a long-sequence task.