
Deconstructing State Space Models: Part 3

Scaling to Mamba and Benchmarking vs Transformers

Abstract

Parts 1 and 2 built the math and a minimal 1D SSM layer. But a scalar-to-scalar model cannot process language -- tokens live in $D_{model}$-dimensional embedding space. In this final post, we wrap our 1D SSM into a multi-channel gated Mamba block, then benchmark it against a PyTorch Transformer on a selective copying task at sequence lengths from 256 to 2048. The results confirm the core claim: Attention memory scales as $O(L^2)$, while the SSM stays linear. Code is on GitHub.

1. Introduction

Self-Attention scales as $O(L^2)$. Double the sequence, quadruple the memory. SSMs compress context into a recurrent hidden state $h$ and process sequences in $O(L)$ time (or $O(L \log L)$ via FFT).

This post scales our 1D layer from Part 2 to handle $D_{model}$-dimensional embeddings, wraps it in Mamba's gated architecture, and runs a head-to-head benchmark against a Transformer.

2. Building the Mamba Block

The Simple1DSSM from Part 2 processes scalar sequences. To handle an input tensor $x \in \mathbb{R}^{B \times L \times D_{model}}$, we project embeddings into an expanded space and run independent 1D SSMs per channel. Two upgrades over the bare SSM: dimension expansion and multiplicative gating.

2.1 Projections and Gating

Input vectors split into two parallel paths: one feeds the SSM channels, the other produces a sigmoid gate.

import torch
import torch.nn as nn

# Assumes the Simple1DSSM layer built in Part 2 is in scope.

class MambaBlock(nn.Module):
    def __init__(self, d_model: int, expand: int = 2, d_state: int = 16):
        super().__init__()
        self.d_inner = d_model * expand

        # 1. Project input for the SSM path
        self.in_proj = nn.Linear(d_model, self.d_inner)

        # 2. Project input for the parallel Gate path
        self.gate_proj = nn.Linear(d_model, self.d_inner)

        # 3. Independent 1D SSMs for every channel
        self.ssm_channels = nn.ModuleList([
            Simple1DSSM(d_state=d_state) for _ in range(self.d_inner)
        ])

        # 4. Fold the high-dimensional representation back down
        self.out_proj = nn.Linear(self.d_inner, d_model)

The expansion from $D_{model}$ to $D_{inner} = D_{model} \times \text{expand}$ gives each independent SSM channel its own slice of the representation to work with.

2.2 The Forward Pass

Both projections run in parallel. The SSM outputs are element-wise multiplied by the gate, then projected back to $D_{model}$.

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x_proj = self.in_proj(x)

        # The gating mechanism acts as a dynamic filter
        gate = torch.sigmoid(self.gate_proj(x))

        # Run SSMs over each channel independently
        ssm_outputs = []
        for i, ssm in enumerate(self.ssm_channels):
            channel_x = x_proj[..., i] # Shape: (batch, seq_len)
            channel_y = ssm(channel_x)
            ssm_outputs.append(channel_y)

        ssm_out = torch.stack(ssm_outputs, dim=-1)  # (batch, seq_len, d_inner)

        # Element-wise gate modulation
        y = ssm_out * gate

        # Final output projection
        return self.out_proj(y)

The SSM × gate pattern is functionally similar to GRU/LSTM gating, but because the gate acts element-wise on the outputs rather than inside the recurrence, the SSM path still benefits from parallel convolution during training. (The per-channel Python loop above is written for clarity; production Mamba implementations fuse all channels into a single parallel scan.)
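As a quick sanity check, the block can be exercised end-to-end on random data. This is a minimal sketch assuming the Simple1DSSM from Part 2 is in scope and accepts a (batch, seq_len) tensor, as it did there:

block = MambaBlock(d_model=64, expand=2, d_state=16)
x = torch.randn(4, 256, 64)   # (batch, seq_len, d_model)
y = block(x)
print(y.shape)                # torch.Size([4, 256, 64]) -- shape is preserved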

3. Benchmarking: SSM vs Transformer

Setup: our MambaBlock vs. PyTorch's nn.TransformerEncoderLayer, matched on parameter count. The task is selective copying -- the model must predict the class of the first token after reading the entire sequence. We tested at $L \in \{256, 512, 1024, 2048\}$ with 10 training steps each.
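To make the task concrete, here is a minimal sketch of a selective-copying batch generator. It is not the exact code from the repo; the vocabulary size, class count, and the convention of placing the class token first are illustrative assumptions:

import torch

def make_selective_copy_batch(batch_size, seq_len, n_classes=8, vocab_size=32):
    # The class token to remember doubles as the label
    labels = torch.randint(0, n_classes, (batch_size,))
    # Distractor tokens drawn from outside the class-token range
    noise = torch.randint(n_classes, vocab_size, (batch_size, seq_len - 1))
    tokens = torch.cat([labels.unsqueeze(1), noise], dim=1)
    return tokens, labels  # tokens: (batch, seq_len), labels: (batch,)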

3.1 Accuracy After 10 Steps

Mamba holds steady across all sequence lengths. The Transformer, relying on pairwise comparisons across the full context, loses signal as $L$ grows -- especially given so few training steps to learn the routing.

3.2 The $O(L^2)$ Memory Wall

The sharper difference is in memory. Going from $L=2048$ to $L=4096$ quadruples the Transformer's attention matrix memory ($O(L^2)$). The Mamba block's convolution kernel grows linearly ($O(L)$), so it can handle much longer contexts on the same hardware.
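The arithmetic behind that wall is easy to check. This sketch just evaluates the two memory formulas in fp32; the batch size, head count, and d_inner values are illustrative assumptions, not the benchmark's exact settings:

def attn_scores_bytes(L, batch=8, heads=8, fp32=4):
    # One L x L score matrix per head per sequence: O(L^2)
    return batch * heads * L * L * fp32

def ssm_activation_bytes(L, batch=8, d_inner=128, fp32=4):
    # SSM activations grow linearly with sequence length: O(L)
    return batch * d_inner * L * fp32

for L in (1024, 2048, 4096):
    print(f"L={L}: attention {attn_scores_bytes(L) / 2**20:.0f} MiB, "
          f"SSM {ssm_activation_bytes(L) / 2**20:.0f} MiB")
# L=1024: attention 256 MiB,  SSM 4 MiB
# L=2048: attention 1024 MiB, SSM 8 MiB
# L=4096: attention 4096 MiB, SSM 16 MiB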

4. Conclusion

Across these three posts, we went from continuous-time ODEs to a working multi-channel Mamba block benchmarked against a Transformer. The key result: ZOH discretization yields a discrete recurrence that unrolls into a convolution for parallel training and collapses back to an $O(1)$-memory recurrence at inference. Wrapping that in a GLU-style gated block produces a model that holds up against Attention on long-range tasks while scaling linearly in sequence length.

Transformers still dominate large-scale pretraining thanks to years of engineering optimization, but hybrid architectures mixing Attention and SSM layers are gaining ground -- particularly for workloads where context length is the bottleneck.