
Deconstructing State Space Models: Part 2

Writing a 1D State Space Layer in PyTorch

Abstract

Part 1 covered the math behind SSMs and Mamba -- continuous ODEs, ZOH discretization, and the recurrence/convolution duality. Here we translate that math into code. Production implementations of S4 and Mamba use custom CUDA kernels and hardware-aware scans, but the core algorithm is simple enough to fit in a single PyTorch nn.Module. This post walks through a minimal, educational 1D LTI State Space layer in pure PyTorch. Full source is on GitHub.

1. Introduction

Official SSM repositories look intimidating -- Triton kernels, parallel prefix scans, low-level memory tricks. But the underlying operation -- mapping a 1D input sequence $x$ to an output $y$ through a hidden state $h$ -- is just standard tensor math. We will build a minimal nn.Module that implements a 1D LTI sequence model.

Three things to demonstrate:

  1. Defining continuous parameters $(\mathbf{A}, \mathbf{B}, \mathbf{C}, \Delta)$ and computing their discrete counterparts $(\bar{\mathbf{A}}, \bar{\mathbf{B}})$.
  2. Unrolling the system into a 1D convolutional kernel for parallel training.
  3. Step-by-step recurrent state updates for autoregressive inference.

2. Breaking Down the Components

To avoid dense matrix exponentials (slow and numerically fragile), we assume $\mathbf{A}$ is diagonal. We then store $\mathbf{A}$ as an $N$-dimensional vector instead of an $N \times N$ matrix, and every operation involving it becomes element-wise.
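
To see why the diagonal assumption pays off, the sketch below compares the general dense route (torch.linalg.matrix_exp on a full $N \times N$ matrix) with the element-wise route. The diagonal values a, the size N, and the step size delta are illustrative assumptions, not values from the module:

```python
import torch

# Hypothetical diagonal of A: N = 4 stable (negative) modes
N = 4
a = -torch.rand(N, dtype=torch.float64) - 1.0
delta = 0.1

# Dense route: materialize the N x N matrix and call the general matrix exponential
dense = torch.linalg.matrix_exp(delta * torch.diag(a))

# Diagonal route: exponentiate only the N diagonal entries
diag_only = torch.exp(delta * a)

# Both routes agree; the diagonal one never builds an N x N matrix
print(torch.allclose(dense, torch.diag(diag_only)))  # True
```

The element-wise version touches $N$ numbers instead of $N^2$ and skips the general-purpose matrix-exponential algorithm entirely.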

2.1 Continuous Initialization

An SSM learns continuous-time parameters representing a dynamical system. The step size $\Delta$ maps these continuous dynamics into discrete token space.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple1DSSM(nn.Module):
    def __init__(self, d_state: int):
        super().__init__()
        # Continuous-time parameters
        # A: (d_state,) diagonal of A, initialized in (-2, -1] so every mode decays stably
        self.A = nn.Parameter(-torch.rand(d_state) - 1.0)

        # B: (d_state, 1) mapping scalar input to state
        self.B = nn.Parameter(torch.randn(d_state, 1))

        # C: (1, d_state) mapping state to scalar output
        self.C = nn.Parameter(torch.randn(1, d_state))

        # Delta: the step size, stored as log(Delta) so that exp() keeps it strictly positive
        self.log_delta = nn.Parameter(torch.randn(1))

The diagonal constraint on $\mathbf{A}$ cuts its storage from $O(N^2)$ to $O(N)$ without losing much expressiveness.
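
The constructor stores $\log \Delta$ rather than $\Delta$, so any real value of log_delta maps to a strictly positive step size, and $\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})$ then turns $\Delta$ into a per-step decay factor for each mode. A small sketch with a hypothetical two-mode $\mathbf{A}$ (values chosen only for illustration) shows the effect:

```python
import torch

# Hypothetical diagonal A: one slow mode (-0.5) and one fast mode (-2.0)
A = torch.tensor([-0.5, -2.0])

decay = {}
for log_delta in (-3.0, 0.0, 1.0):
    delta = torch.exp(torch.tensor(log_delta))  # strictly positive for any real log_delta
    decay[log_delta] = torch.exp(delta * A)     # A_bar: per-step decay factor of each mode
    print(f"delta={delta.item():.3f}  A_bar={decay[log_delta].tolist()}")
```

Smaller $\Delta$ keeps every factor close to 1, so the state holds on to older tokens; larger $\Delta$ drives the fast mode toward 0. The negative initialization of A matters here as well: a positive entry would give $\exp(\Delta a_i) > 1$ and a state that grows every step.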

2.2 ZOH Discretization

Zero-Order Hold converts the continuous $(\mathbf{A}, \mathbf{B})$ into the discrete $(\bar{\mathbf{A}}, \bar{\mathbf{B}})$ needed for token-level processing.

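    def discretize(self):
        delta = torch.exp(self.log_delta)  # (1,), strictly positive

        # ZOH: A_bar = exp(delta * A). For a diagonal A the matrix
        # exponential simplifies to an element-wise exp. Shape: (d_state,)
        A_bar = torch.exp(delta * self.A)

        # ZOH: B_bar = (delta*A)^-1 (exp(delta*A) - I) delta*B. The deltas cancel,
        # and for a diagonal A this is element-wise (A_bar - 1) / A * B. Shape: (d_state, 1)
        B_bar = ((A_bar - 1.0) / self.A).unsqueeze(-1) * self.B

        return A_bar, B_bar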

Because $\mathbf{A}$ is diagonal, the matrix exponential $\exp(\Delta \mathbf{A})$ reduces to element-wise torch.exp(). No eigendecomposition required.
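
Written out, the ZOH formulas and their diagonal special case look like this; the second line is what the B_bar expression in discretize() computes (the scalar $\Delta$ commutes with $\mathbf{A}$, so the two $\Delta$ factors cancel):

$$
\bar{\mathbf{A}} = \exp(\Delta \mathbf{A}), \qquad
\bar{\mathbf{B}} = (\Delta \mathbf{A})^{-1}\left(\exp(\Delta \mathbf{A}) - \mathbf{I}\right)\Delta \mathbf{B} = \mathbf{A}^{-1}\left(\bar{\mathbf{A}} - \mathbf{I}\right)\mathbf{B}
$$

$$
\mathbf{A} = \operatorname{diag}(a_1, \dots, a_N) \quad\Longrightarrow\quad \bar{a}_i = e^{\Delta a_i}, \qquad \bar{b}_i = \frac{\bar{a}_i - 1}{a_i}\, b_i
$$

The division by $a_i$ assumes no diagonal entry is exactly zero, which the negative initialization of A guarantees at the start of training.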

2.3 Sequence Convolution (Fast Training)

With the full sequence available at training time, we precompute an $L$-length impulse-response kernel $\mathbf{K}$. Entry $K[k]$ encodes how much the input token $x_{T-k}$, which sits $k$ steps in the past, contributes to the output $y_{T}$.
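
Unrolling the recurrence from a zero initial state shows where the kernel comes from (like the module, this omits any skip term):

$$
h_t = \bar{\mathbf{A}} h_{t-1} + \bar{\mathbf{B}} x_t, \quad h_{-1} = 0 \quad\Longrightarrow\quad h_t = \sum_{k=0}^{t} \bar{\mathbf{A}}^{k} \bar{\mathbf{B}}\, x_{t-k}
$$

$$
y_t = \mathbf{C} h_t = \sum_{k=0}^{t} K[k]\, x_{t-k}, \qquad K[k] = \mathbf{C} \bar{\mathbf{A}}^{k} \bar{\mathbf{B}}
$$

Since $\mathbf{K}$ depends only on the parameters and not on $x$, it is computed once per forward pass and then applied to the whole sequence as a single causal convolution.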

    def forward(self, x):
        """ Convolutional execution for rapid parallel training.

        x: (batch, L) input sequence. Returns y: (batch, L).
        """
        batch, L = x.shape
        A_bar, B_bar = self.discretize()

        # Compute global convolutional kernel: K[k] = C A_bar^k B_bar,
        # i.e. [CB, CAB, CA^2B, ...] with the discretized A_bar and B_bar
        K = torch.zeros(L, device=x.device, dtype=x.dtype)
        A_pow = torch.ones_like(self.A)

        for k in range(L):
            K[k] = (self.C @ (A_pow.unsqueeze(-1) * B_bar)).squeeze()
            A_pow = A_pow * A_bar

        K = K.view(1, 1, L)        # (out_channels, in_channels, kernel_size)
        x = x.view(batch, 1, L)    # (batch, channels, length)

        # Note: PyTorch's F.conv1d computes cross-correlation!
        # Flip the kernel so output t sees K[0] * x_t + K[1] * x_{t-1} + ...
        K_flipped = torch.flip(K, dims=(-1,))
        y = F.conv1d(x, K_flipped, padding=L-1)

        # padding=L-1 pads both ends; the first L outputs only read x_0..x_t
        y = y[..., :L]
        return y.squeeze(1)

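Two details worth noting: we pad by $L-1$ on both sides and keep only the first $L$ outputs, so output $t$ reads only $x_0, \dots, x_t$; and we explicitly flip the kernel before calling F.conv1d. PyTorch's F.conv1d computes cross-correlation, not convolution -- skip the flip and, with this padding and slicing, the output stays causal but the kernel runs back to front: the current token receives $K[L-1]$, the weight meant for the oldest token, instead of $K[0]$.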
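
A quick way to check both points is to compare the F.conv1d recipe against the causal sum $y_t = \sum_{k \le t} K[k]\, x_{t-k}$ written out by hand. A minimal sketch with a random toy kernel (sizes and seed are arbitrary choices, and K here is not produced by an SSM):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L = 6
x = torch.randn(2, L, dtype=torch.float64)  # (batch, L) toy input
K = torch.randn(L, dtype=torch.float64)     # toy kernel; K[0] weights the current token

# Reference: the causal convolution written out directly
y_ref = torch.zeros_like(x)
for t in range(L):
    for k in range(t + 1):
        y_ref[:, t] += K[k] * x[:, t - k]

def conv_path(kernel):
    # Same recipe as forward(): pad L-1 on both sides, keep the first L outputs
    y = F.conv1d(x.view(2, 1, L), kernel.view(1, 1, L), padding=L - 1)
    return y[..., :L].squeeze(1)

y_flipped = conv_path(torch.flip(K, dims=(0,)))
y_unflipped = conv_path(K)

print(torch.allclose(y_flipped, y_ref))    # True
print(torch.allclose(y_unflipped, y_ref))  # False
```

The unflipped output at $t = 0$ is $K[L-1]\,x_0$ rather than $K[0]\,x_0$, which is exactly the back-to-front pairing.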

2.4 O(1) Autoregressive Inference

At generation time, we switch to recurrent mode: ingest one scalar $x_k$, update the fixed-size state via $\bar{\mathbf{A}}$ and $\bar{\mathbf{B}}$, emit $y_k$.

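    def step(self, x_k, h_prev):
        """ Recurrent execution for O(1) memory autoregressive inference.

        x_k: (batch,) or (batch, 1) input at step k; h_prev: (batch, d_state).
        Returns y_k: (batch, 1) and h_k: (batch, d_state).
        """
        A_bar, B_bar = self.discretize()
        x_k = x_k.reshape(-1, 1)  # (batch, 1), so it never broadcasts against (d_state,)

        # State Update: h_k = A_bar * h_prev + B_bar * x_k
        h_k = A_bar * h_prev + (B_bar.squeeze(-1) * x_k)

        # Output Projection: y_k = C * h_k
        y_k = (self.C * h_k).sum(dim=-1, keepdim=True)

        return y_k, h_k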

Token 10 and token 10,000 cost the same: each step is one element-wise update of the $N$-dimensional state, whose size is fixed regardless of how long the sequence gets.
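
The forward-vs-step comparison used as a sanity check in Section 3 can be written as a short script. A sketch, assuming batch-first inputs of shape (batch, L), a zero initial state, and float64 to keep rounding noise out of the comparison; the class is a condensed version of the module above, with x_k reshaped to (batch, 1) inside step():

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple1DSSM(nn.Module):
    # Condensed version of the module above: same parameters and the same three methods
    def __init__(self, d_state: int):
        super().__init__()
        self.A = nn.Parameter(-torch.rand(d_state) - 1.0)
        self.B = nn.Parameter(torch.randn(d_state, 1))
        self.C = nn.Parameter(torch.randn(1, d_state))
        self.log_delta = nn.Parameter(torch.randn(1))

    def discretize(self):
        delta = torch.exp(self.log_delta)
        A_bar = torch.exp(delta * self.A)
        B_bar = ((A_bar - 1.0) / self.A).unsqueeze(-1) * self.B
        return A_bar, B_bar

    def forward(self, x):
        batch, L = x.shape
        A_bar, B_bar = self.discretize()
        K = torch.zeros(L, device=x.device, dtype=x.dtype)
        A_pow = torch.ones_like(self.A)
        for k in range(L):
            K[k] = (self.C @ (A_pow.unsqueeze(-1) * B_bar)).squeeze()
            A_pow = A_pow * A_bar
        K_flipped = torch.flip(K.view(1, 1, L), dims=(-1,))
        y = F.conv1d(x.view(batch, 1, L), K_flipped, padding=L - 1)
        return y[..., :L].squeeze(1)

    def step(self, x_k, h_prev):
        A_bar, B_bar = self.discretize()
        x_k = x_k.reshape(-1, 1)
        h_k = A_bar * h_prev + B_bar.squeeze(-1) * x_k
        y_k = (self.C * h_k).sum(dim=-1, keepdim=True)
        return y_k, h_k

torch.manual_seed(0)
batch, L, d_state = 3, 64, 8
model = Simple1DSSM(d_state).double()
x = torch.randn(batch, L, dtype=torch.float64)

with torch.no_grad():
    y_conv = model(x)  # (batch, L), parallel convolutional mode

    # Recurrent mode: start from a zero state and feed one token at a time
    h = torch.zeros(batch, d_state, dtype=torch.float64)
    outputs = []
    for t in range(L):
        y_t, h = model.step(x[:, t], h)
        outputs.append(y_t)
    y_rec = torch.cat(outputs, dim=-1)  # (batch, L)

max_err = (y_conv - y_rec).abs().max().item()
print(f"max |forward - step| = {max_err:.2e}")
```

A discrepancy far above rounding error points to a logic mismatch between the two code paths rather than numerical noise.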

3. Debugging the Duality

My first sanity check -- verifying that forward() and a loop of step() calls produce identical outputs -- failed with a discrepancy of $1.4 \times 10^{11}$. Two bugs:

4. Next Steps

The full Simple1DSSM module is in the GitHub repo.

What we have here maps a single scalar sequence to another scalar sequence. Real language models operate over $D_{model}$-dimensional embeddings. In Part 3, we wrap this 1D core into a multi-channel Mamba block and benchmark it against a Transformer on a long-sequence copying task.