Deconstructing State Space Models: Part 3

Scaling to Mamba and Benchmarking vs Transformers

Abstract

In our previous installments, we established the mathematical foundations of State Space Models (Part 1) and translated those dynamical equations into a minimal, educational 1D SSM layer in pure PyTorch (Part 2). However, a single 1D sequence transformation mapping scalars to scalars remains insufficient for modeling language, where tokens are represented by deep $D_{model}$-dimensional embeddings. In Part 3, we bridge the gap between our basic 1D algorithmic core and modern architectures by wrapping our SSM into a full multi-channel "Mamba" block. Finally, we pit this custom architecture against a baseline PyTorch Transformer on a long-sequence task, demonstrating the promised theoretical scaling: quadratic performance degradation in Attention versus linear, memory-efficient scaling in State Space Models. The complete testing application is available in my GitHub repository.

1. Introduction

Transformers rule the Deep Learning era by brute-forcing global context routing via Self-Attention. However, every practitioner eventually hits the Attention memory wall: when the sequence length doubles, the compute and memory requirements quadruple ($O(L^2)$ time and memory complexity).

State Space Models compress this contextual routing into a recurrent hidden state $h$. As established in Part 1 and Part 2, learning continuous representations and applying parallel discrete convolutions enables an SSM to execute sequence transformations in linear time ($O(L)$ or $O(L \log L)$ with FFT).
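To make the FFT claim concrete, here is a small sketch (with a random stand-in kernel; in a real SSM the kernel is derived from the discretized $A$, $B$, $C$ parameters) showing that the causal convolution at the heart of the SSM matches a direct $O(L^2)$ computation while costing only $O(L \log L)$:

```python
import torch

torch.manual_seed(0)
L = 256
x = torch.randn(L, dtype=torch.float64)  # input sequence
k = torch.randn(L, dtype=torch.float64)  # stand-in SSM kernel (random here)

# Direct causal convolution: O(L^2) multiply-adds
direct = torch.stack(
    [sum(k[j] * x[t - j] for j in range(t + 1)) for t in range(L)]
)

# FFT path: zero-pad to 2L so circular convolution equals linear convolution
n = 2 * L
fft_out = torch.fft.irfft(torch.fft.rfft(x, n) * torch.fft.rfft(k, n), n)[:L]

print(torch.allclose(direct, fft_out, atol=1e-8))  # -> True
```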

In this final post of the mini-series, we will scale our 1D layer to accept modern embedding representations, wrap it in the gated structure popularized by architectures like Mamba, and benchmark its empirical inference scaling against the incumbent Transformer class.

2. Building the Mamba Block

Our Simple1DSSM class from Part 2 operates strictly on sequences of scalar values. To process an input tensor $x \in \mathbb{R}^{B \times L \times D_{model}}$, we must project our embeddings and run the 1D SSM independently across derived channels.

Mamba-style architectures introduce two key upgrades to the basic SSM: expansion dimensions and multiplicative gating paths.

2.1 Projections and Gating

We split the incoming token vectors into two paths: an activation branch (the gate) and an SSM branch.

import torch
import torch.nn as nn

class MambaBlock(nn.Module):
    def __init__(self, d_model: int, expand: int = 2, d_state: int = 16):
        super().__init__()
        self.d_inner = d_model * expand
        
        # 1. Project input for the SSM path
        self.in_proj = nn.Linear(d_model, self.d_inner)
        
        # 2. Project input for the parallel Gate path
        self.gate_proj = nn.Linear(d_model, self.d_inner)
        
        # 3. Independent 1D SSMs for every channel
        self.ssm_channels = nn.ModuleList([
            Simple1DSSM(d_state=d_state) for _ in range(self.d_inner)
        ])
        
        # 4. Fold the high-dimensional representation back down
        self.out_proj = nn.Linear(self.d_inner, d_model)

Expanding $D_{model}$ into $D_{inner}$ gives the independent 1D channels room to process complex, disentangled representational features.

2.2 The Forward Pass

The forward execution runs the inputs through both projections, applies the sequence transformations, and mixes the outputs.

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x_proj = self.in_proj(x)
        
        # The gating mechanism acts as a dynamic filter
        gate = torch.sigmoid(self.gate_proj(x))
        
        # Run SSMs over each channel independently
        ssm_outputs = []
        for i, ssm in enumerate(self.ssm_channels):
            channel_x = x_proj[..., i] # Shape: (batch, seq_len)
            channel_y = ssm(channel_x)
            ssm_outputs.append(channel_y)
            
        ssm_out = torch.stack(ssm_outputs, dim=-1)
        
        # Element-wise gate modulation
        y = ssm_out * gate 
        
        # Final output projection
        return self.out_proj(y)

This SSM $\times$ Gate motif is crucial; it behaves similarly to a Gated Recurrent Unit (GRU) or LSTM gating mechanism but with significantly faster parallel sequence execution.
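Putting the pieces together, here is a self-contained smoke test of the block. Since the exact Simple1DSSM code lives in Part 2, a toy stand-in with the same interface (a learned diagonal linear recurrence over scalar sequences; parameter names are illustrative) is used here:

```python
import torch
import torch.nn as nn

class Simple1DSSM(nn.Module):
    """Toy stand-in for the Part 2 layer: h_t = A*h_{t-1} + B*x_t, y_t = C.h_t,
    applied to (batch, seq_len) scalar sequences."""
    def __init__(self, d_state: int = 4):
        super().__init__()
        self.A = nn.Parameter(torch.full((d_state,), 0.9))  # diagonal decay
        self.B = nn.Parameter(torch.randn(d_state) * 0.1)
        self.C = nn.Parameter(torch.randn(d_state) * 0.1)

    def forward(self, x):                       # x: (batch, seq_len)
        h = torch.zeros(x.shape[0], self.A.shape[0], device=x.device)
        ys = []
        for t in range(x.shape[1]):             # sequential scan for clarity
            h = self.A * h + self.B * x[:, t:t+1]
            ys.append(h @ self.C)
        return torch.stack(ys, dim=1)           # (batch, seq_len)

class MambaBlock(nn.Module):                    # same structure as above
    def __init__(self, d_model: int, expand: int = 2, d_state: int = 4):
        super().__init__()
        self.d_inner = d_model * expand
        self.in_proj = nn.Linear(d_model, self.d_inner)
        self.gate_proj = nn.Linear(d_model, self.d_inner)
        self.ssm_channels = nn.ModuleList(
            [Simple1DSSM(d_state) for _ in range(self.d_inner)])
        self.out_proj = nn.Linear(self.d_inner, d_model)

    def forward(self, x):                       # x: (batch, seq_len, d_model)
        x_proj = self.in_proj(x)
        gate = torch.sigmoid(self.gate_proj(x))
        ssm_out = torch.stack(
            [ssm(x_proj[..., i]) for i, ssm in enumerate(self.ssm_channels)],
            dim=-1)
        return self.out_proj(ssm_out * gate)

torch.manual_seed(0)
block = MambaBlock(d_model=8)
y = block(torch.randn(2, 16, 8))                # (batch=2, seq_len=16, d_model=8)
print(tuple(y.shape))                           # -> (2, 16, 8)
```

The shape check confirms the block is a drop-in sequence transformation: it consumes and produces $(B, L, D_{model})$ tensors, so it can be stacked like any Transformer layer.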

3. Benchmarking: SSM vs Transformer

With our fully functional PyTorch Mamba block assembled, we can test the primary hypothesis of State Space Models: linear scaling with robust long-range routing.

We instantiated our MambaBlock alongside PyTorch's native nn.TransformerEncoderLayer, with both modules constrained to matched parameter counts. To test their memory-routing capabilities, we fed synthetic sequences of varying lengths ($L \in \{256, 512, 1024, 2048\}$) through both models on a recall task in the spirit of selective copying: predicting the class of the very first token at the very end of the sequence.

We trained tiny proxy models for exactly 10 steps to observe how quickly each architecture degrades when stretched across long contexts.
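The exact data generator lives in the repository; a minimal sketch of the task described above (the function name and generation details here are illustrative) might look like:

```python
import torch

def make_recall_batch(batch_size: int, seq_len: int, n_classes: int = 8):
    """Synthetic 'remember the first token' batch: the model reads a
    sequence of random class ids and must predict, at the final position,
    the class of position 0."""
    tokens = torch.randint(0, n_classes, (batch_size, seq_len))
    targets = tokens[:, 0]  # the label is the class of the very first token
    return tokens, targets

tokens, targets = make_recall_batch(batch_size=4, seq_len=256)
print(tokens.shape, targets.shape)  # torch.Size([4, 256]) torch.Size([4])
```

Because the answer depends only on position 0, accuracy at the final position directly measures how well each architecture carries information across the full context length.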

3.1 Empirical Accuracy Results

After 10 proxy training steps, our quick local benchmark showed a clear accuracy gap between the two architectures.

Notice how Mamba consistently maintains its ability to route the state information across thousands of steps, even as the sequence stretches to 2048 tokens. The Transformer, unable to compress internal state efficiently and relying solely on global pairwise comparisons, struggles to maintain the signal over distance given such limited training capacity.

3.2 The $O(L^2)$ Memory Wall

Beyond raw accuracy, the most striking difference appears in the memory required to maintain the computational graph. Doubling a Transformer's sequence length from 2048 to 4096 roughly quadruples its attention memory ($O(L^2)$) because the pairwise score matrix explodes. By contrast, the Mamba block's global sequence convolution scales linearly ($O(L)$), circumventing that quadratic collapse and allowing it to process far larger contexts on the exact same hardware.
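A quick back-of-the-envelope calculation (assuming fp16 scores; figures are per attention head, per batch element) shows the quadratic term in action:

```python
def attn_score_mib(seq_len: int, bytes_per_el: int = 2) -> float:
    """Memory for one L x L attention score matrix, in MiB."""
    return seq_len * seq_len * bytes_per_el / 2**20

for L in (2048, 4096):
    print(f"L={L}: {attn_score_mib(L):.0f} MiB per head")
# Doubling L from 2048 to 4096 quadruples this term: 8 MiB -> 32 MiB,
# and the cost is paid once per head, per layer, per batch element.
```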

4. Conclusion

Over the course of this "Build in Public" series, we've dragged State Space Models out of the abstract realm of classical Control Theory and materialized them in modern deep learning code.

We've observed that mapping continuous dynamics into discrete matrices via Zero-Order Hold opens the door to parallel convolutions ($O(L)$, or $O(L \log L)$ with the FFT) for rapid training, while guaranteeing $O(1)$ memory per step for autoregressive recurrent generation. By wrapping that basic math in a multi-channel GLU-style block, we replicated the core architecture powering the next generation of foundation models like Mamba.
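As a one-line recap of Part 1, the Zero-Order Hold discretization for a scalar state is $\bar{A} = e^{\Delta A}$ and $\bar{B} = A^{-1}(e^{\Delta A} - 1)B$; numerically, with toy values:

```python
import math

A, B, dt = -0.5, 1.0, 0.1  # toy continuous parameters and step size

A_bar = math.exp(dt * A)          # decay applied per discrete step
B_bar = (A_bar - 1.0) / A * B     # exact ZOH input scaling (valid for A != 0)

print(round(A_bar, 6), round(B_bar, 6))  # -> 0.951229 0.097541
```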

While Transformers currently remain the undisputed kings of massive-scale pretraining due to immense community-driven optimization momentum, architectures combining the strengths of Attention with the efficiency of State Space Models are rapidly cementing themselves as the definitive standard for processing massive context windows and genomic sequences.

Thank you for following along with this deconstruction series!