Deconstructing Spiking Neural Networks (SNNs)

Part 2: Writing a LIF Layer in Pure PyTorch

Introduction

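In Part 1 we derived the LIF neuron and ran into a wall: spike emission, $S[t] = \Theta(V[t] - V_{th})$, is a Heaviside step whose derivative is zero everywhere except at the threshold, where it is undefined. Any gradient flowing back through a spike is therefore zero, so backpropagation cannot update the weights upstream of it.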
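
A quick numerical check makes the problem concrete. The plain-Python sketch below (the helpers `heaviside` and `central_difference` are illustrative names, not from any library) estimates the derivative of the spike function with a central finite difference at a few membrane potentials, assuming $V_{th} = 1$.

```python
def heaviside(v, v_th=1.0):
    """Spike emission: 1.0 if the membrane potential reaches threshold, else 0.0."""
    return 1.0 if v >= v_th else 0.0


def central_difference(f, v, h=1e-4):
    """Numerical estimate of df/dv at v."""
    return (f(v + h) - f(v - h)) / (2.0 * h)


for v in (0.2, 0.9, 1.0, 1.5):
    print(f"V = {v:.1f}  dS/dV ~ {central_difference(heaviside, v):g}")
```

Away from the threshold the estimate is exactly 0; at the threshold it scales as $1/(2h)$ and grows without bound as `h` shrinks. Neither gives the optimizer a usable signal.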

Here we fix that and build a working LIF layer in pure PyTorch, without any external SNN framework.

The Surrogate Gradient Trick

We do not need the true gradient of the Heaviside. We need a gradient signal good enough to push the weights in a useful direction. The fix: keep the forward pass exact, but substitute a smooth approximation on the backward pass.

Forward pass (unchanged):

$$ S[t] = \Theta(V[t] - V_{th}) = \begin{cases} 1 & \text{if } V[t] \geq V_{th} \\ 0 & \text{otherwise} \end{cases} $$

Backward pass, using the derivative of the fast sigmoid surrogate:

$$ \frac{\partial \hat{S}}{\partial V} \approx \frac{1}{\left(1 + k \cdot |V - V_{th}|\right)^2} $$

This function peaks at $V = V_{th}$ and decays toward zero on either side of the threshold. The sharpness $k$ controls how tightly the surrogate hugs the true step; we use $k = 10$. At the threshold the surrogate gradient equals 1, so the full upstream signal passes through. It falls off quickly: at $|V - V_{th}| = 0.1$ it is already $1/4$, and at $|V - V_{th}| = 0.5$ it is $1/36 \approx 0.03$. The gradient is therefore strongest exactly where the neuron is "about to fire", which matches intuition: a small change in synaptic weight can only flip the spike decision of a neuron that is already close to its threshold.
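
To see how quickly the surrogate falls off, here is the formula evaluated directly in plain Python. The function name `fast_sigmoid_grad` is ours; $V_{th} = 1.0$ is the `LIFNeuron` default threshold and $k = 10$ matches the value above. The second loop varies $k$ to show how it narrows the window in which gradient flows.

```python
def fast_sigmoid_grad(v, v_th=1.0, k=10.0):
    """Surrogate dS/dV: 1 / (1 + k*|V - V_th|)^2."""
    return 1.0 / (1.0 + k * abs(v - v_th)) ** 2


# Distance from threshold vs. surrogate gradient (k = 10)
for dist in (0.0, 0.1, 0.5, 1.0):
    print(f"|V - V_th| = {dist:.1f}  ->  {fast_sigmoid_grad(1.0 + dist):.4f}")

# Larger k concentrates the gradient more tightly around the threshold
for k in (2.0, 10.0, 25.0):
    print(f"k = {k:>4}: gradient at |V - V_th| = 0.2 is {fast_sigmoid_grad(1.2, k=k):.4f}")
```

With $k = 10$ the gradient is 1 at the threshold, 0.25 at a distance of 0.1 and about 0.008 at a distance of 1. The function is symmetric, so neurons just below and just above threshold receive the same signal.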

Implementing the Surrogate in PyTorch

PyTorch's torch.autograd.Function lets us define separate forward and backward logic:

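import torch


class SurrogateHeaviside(torch.autograd.Function):
    """Heaviside step on the forward pass, fast-sigmoid derivative on the backward pass."""

    @staticmethod
    def forward(ctx, membrane_potential, threshold):
        # Save what the backward pass needs; threshold is a plain Python float
        ctx.save_for_backward(membrane_potential)
        ctx.threshold = threshold
        # Exact, non-differentiable spike: 1 where V >= V_th, else 0
        return (membrane_potential >= threshold).to(membrane_potential.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (membrane_potential,) = ctx.saved_tensors
        k = 10.0  # surrogate sharpness
        shifted = membrane_potential - ctx.threshold
        surrogate = 1.0 / (1.0 + k * shifted.abs()) ** 2
        # One gradient per forward input: membrane_potential, then threshold (none)
        return grad_output * surrogate, None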
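
The same forward/backward split can also be written without a custom `autograd.Function`, using the detach trick: add a term whose forward value is exactly zero but whose gradient is the derivative of the fast sigmoid $x / (1 + k|x|)$, which is precisely $1/(1 + k|x|)^2$. This is an alternative sketched for comparison, not the implementation used in this series, and the helper name `spike_with_surrogate` is hypothetical. The check at the end confirms the exact step in the forward pass and the surrogate gradient in the backward pass.

```python
import torch


def spike_with_surrogate(v, threshold=1.0, k=10.0):
    """Hypothetical helper: Heaviside forward, fast-sigmoid gradient backward."""
    x = v - threshold
    hard = (x >= 0).to(v.dtype)        # exact spike; comparisons carry no gradient
    soft = x / (1.0 + k * x.abs())     # fast sigmoid; d(soft)/dx = 1 / (1 + k|x|)^2
    # (soft - soft.detach()) is exactly 0 in value but passes soft's gradient back
    return hard + (soft - soft.detach())


v = torch.tensor([0.5, 1.0, 1.1, 2.0], requires_grad=True)
spikes = spike_with_surrogate(v)
spikes.sum().backward()

expected_grad = 1.0 / (1.0 + 10.0 * (v.detach() - 1.0).abs()) ** 2
print(spikes.detach())   # exact 0/1 spikes
print(v.grad)            # matches expected_grad
```

The custom `Function` is more explicit and avoids computing the soft term on the forward pass; the detach version lets autograd derive the backward pass for you.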

The LIF Neuron Layer

With the surrogate in place, the LIF update equations from Part 1 map directly to an nn.Module:

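import math

import torch
import torch.nn as nn


def inv_sigmoid(p):
    """Logit: the inverse of torch.sigmoid for a scalar p in (0, 1)."""
    return math.log(p / (1.0 - p))


class LIFNeuron(nn.Module):
    def __init__(self, num_neurons, decay=0.9, threshold=1.0):
        super().__init__()
        self.num_neurons = num_neurons
        self.threshold = threshold
        # Learnable decay per neuron (kept in (0,1) via sigmoid)
        self._decay_logit = nn.Parameter(
            torch.full((num_neurons,), inv_sigmoid(decay))
        )

    @property
    def decay(self):
        return torch.sigmoid(self._decay_logit)

    def init_membrane(self, batch_size):
        # Assumed: every neuron starts at rest (V = 0), on the parameters' device
        return torch.zeros(batch_size, self.num_neurons,
                           device=self._decay_logit.device)

    def forward(self, current, membrane):
        # 1. Leak + integrate
        new_membrane = self.decay * membrane + current
        # 2. Fire (surrogate gradient backward)
        spike = SurrogateHeaviside.apply(new_membrane, self.threshold)
        # 3. Soft reset: subtract threshold from fired neurons
        new_membrane = new_membrane - spike * self.threshold
        return spike, new_membrane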

The decay factor $\beta$ is learnable per neuron, so the network learns how much memory each neuron retains, loosely analogous to the adaptive time constants in Liquid Neural Networks. To keep $\beta$ in the valid range $(0,1)$, the raw parameter is stored as a logit (initialized with the inverse sigmoid of the starting decay), and the effective decay is recovered as torch.sigmoid(self._decay_logit). The optimizer is free to move the logit anywhere on the real line, yet the decay it maps to always stays strictly between 0 and 1 throughout training.
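
The reparameterization is easy to check in isolation. The plain-Python sketch below uses hand-written `inv_sigmoid` and `sigmoid` helpers (our own names, mirroring the PyTorch calls above). It shows that the initial decay of 0.9 round-trips through its logit, that even large moves of the logit map strictly inside $(0,1)$, and that the chain-rule factor $\sigma'(z) = \beta(1-\beta)$ scales the gradient reaching the logit.

```python
import math


def inv_sigmoid(p):
    """Logit: maps a decay in (0, 1) to an unconstrained real number."""
    return math.log(p / (1.0 - p))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


logit = inv_sigmoid(0.9)          # initial value of _decay_logit, about 2.197
beta = sigmoid(logit)             # recovers the decay, about 0.9
local_grad = beta * (1.0 - beta)  # d(beta)/d(logit) = sigma'(z), about 0.09

# Large moves of the logit never push beta outside (0, 1)
for z in (logit - 8.0, logit, logit + 8.0):
    print(f"logit = {z:+.3f}  ->  decay = {sigmoid(z):.6f}")
```

One side effect of this design: as $\beta$ approaches 0 or 1, the factor $\beta(1-\beta)$ shrinks, so the decay of near-saturated neurons changes more slowly.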

The soft reset in step 3 is worth noting: rather than hard-resetting the membrane to zero after a spike, we subtract the threshold value. If a neuron received a very strong input that pushed it well past threshold, the excess charge carries over into the next timestep. This preserves information that would otherwise be lost and tends to produce more stable training dynamics.
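
A single-neuron simulation makes the difference visible. This plain-Python sketch (the `simulate` helper and the input currents are illustrative choices, not from the original) follows the same order as `LIFNeuron.forward`: leak and integrate, fire, then reset. A strong first input of 1.8 overshoots the threshold; the soft reset keeps the excess 0.8, which lets a modest later input of 0.4 trigger a second spike that the hard reset misses.

```python
def simulate(currents, decay=0.9, threshold=1.0, soft_reset=True):
    """Single LIF neuron; returns the spike train and post-reset membrane trace."""
    v, spikes, trace = 0.0, [], []
    for current in currents:
        v = decay * v + current             # 1. leak + integrate
        spike = 1 if v >= threshold else 0  # 2. fire
        if spike:                           # 3. reset
            v = v - threshold if soft_reset else 0.0
        spikes.append(spike)
        trace.append(round(v, 4))
    return spikes, trace


currents = [1.8, 0.0, 0.4]
print("soft:", simulate(currents, soft_reset=True))
print("hard:", simulate(currents, soft_reset=False))
```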

Assembling the SNN Classifier

The full classifier alternates linear projections with LIF layers. The same input is presented at each of $T=25$ timesteps, and the readout accumulates output logits:

class SNNClassifier(nn.Module):
    def __init__(self, input_size=784, hidden_size=256, output_size=10):
        super().__init__()
        self.fc1    = nn.Linear(input_size, hidden_size)
        self.fc2    = nn.Linear(hidden_size, hidden_size)
        self.fc_out = nn.Linear(hidden_size, output_size)
        self.lif1   = LIFNeuron(hidden_size)
        self.lif2   = LIFNeuron(hidden_size)

    def forward(self, x, num_steps=25):
        # Accept (B, 1, 28, 28) images or already-flat (B, 784) input
        x = x.flatten(1)
        batch_size = x.shape[0]
        mem1 = self.lif1.init_membrane(batch_size)
        mem2 = self.lif2.init_membrane(batch_size)
        # Readout sums non-spiking logits over time; size and device follow the model
        accumulator = torch.zeros(batch_size, self.fc_out.out_features, device=x.device)
        # The input is static, so its projection is identical at every step
        cur1 = self.fc1(x)
        for _ in range(num_steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            accumulator = accumulator + self.fc_out(spk2)
        return accumulator

At every timestep the same input drives both hidden layers and the readout adds its logits to the accumulator; after all $T = 25$ steps, the predicted class is the argmax of the accumulated logits. This is rate coding in action: the input is static, so the hidden layers carry information in how often each neuron fires across the $T$ steps, and the output units with the largest accumulated logits determine the predicted class. The temporal unrolling is conceptually similar to running a recurrent network for $T$ steps, except that the recurrent state is the membrane potential and the nonlinearity is a binary spike rather than a tanh or sigmoid.
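
Rate coding can be illustrated with one neuron driven by a constant current, which is how each hidden neuron in the first layer sees its static projected input. In this plain-Python sketch, `spike_count` is a hypothetical helper using the same decay of 0.9, threshold of 1.0 and soft reset as `LIFNeuron`. With $\beta = 0.9$ the membrane settles at $I/(1-\beta)$, so any constant current below 0.1 never fires; stronger currents fire more often, and a current equal to the threshold fires at every one of the $T = 25$ steps.

```python
def spike_count(current, num_steps=25, decay=0.9, threshold=1.0):
    """Spikes emitted by one soft-reset LIF neuron under a constant input current."""
    v, count = 0.0, 0
    for _ in range(num_steps):
        v = decay * v + current
        if v >= threshold:
            count += 1
            v -= threshold
    return count


for current in (0.05, 0.3, 0.6, 1.0):
    print(f"I = {current:.2f}: {spike_count(current)} spikes in 25 steps")
```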

One architectural note: the readout layer (fc_out) turns the second LIF layer's spikes into real-valued logits, which are accumulated directly rather than converted into further spikes. This avoids the information bottleneck of a binary output layer and lets us apply the standard CrossEntropyLoss to the accumulated output. The model has 269,834 trainable parameters: 269,322 in its three linear layers, exactly matching a comparable MLP with two hidden layers of 256 units, plus 512 learnable decay logits (one per LIF neuron). This keeps the Part 3 benchmark fair.
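
The parameter counts follow directly from the layer sizes. A quick plain-Python check (the helper `linear_params` is our own name):

```python
def linear_params(n_in, n_out):
    """Weights plus biases of an nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out


# The three linear layers shared by both models: 784 -> 256 -> 256 -> 10
linear_total = (linear_params(784, 256)
                + linear_params(256, 256)
                + linear_params(256, 10))

decay_logits = 2 * 256          # one learnable decay logit per LIF neuron
snn_total = linear_total + decay_logits

print(linear_total, snn_total)  # 269322 269834
```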

Why Spikes Save Energy

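On GPUs, every weight interaction in an ANN is a floating-point Multiply-Accumulate (MAC). On neuromorphic chips (Intel Loihi, IBM TrueNorth), a spike triggers only a simple integer Accumulate (AC): because a spike is either 0 or 1, the product of weight and spike is either 0 or the weight itself, so no multiply is needed. An AC costs roughly $5\times$ less energy than a MAC.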

Two factors compound:

  1. Operation type: AC instead of MAC (~$5\times$ cheaper per op).
  2. Sparsity: silent neurons generate zero operations.
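
Both factors show up in a toy layer. The plain-Python sketch below (hypothetical helper names and a made-up 3x4 weight matrix) computes the same output two ways: a dense pass that performs one multiply-accumulate per weight, and an event-driven pass that, for each input that spiked, simply adds that input's column of weights. Because spikes are 0 or 1 the results are identical, but the event-driven pass uses no multiplications and never reads the weights of silent inputs.

```python
def dense_layer(spikes, weights):
    """ANN-style: one MAC per weight, regardless of the input value."""
    out, macs = [0.0] * len(weights), 0
    for i, row in enumerate(weights):
        for w, s in zip(row, spikes):
            out[i] += w * s
            macs += 1
    return out, macs


def event_driven_layer(spikes, weights):
    """Neuromorphic-style: each spike adds its weight column; no multiplies."""
    out, acs = [0.0] * len(weights), 0
    for j, s in enumerate(spikes):
        if not s:
            continue  # silent input: its weights are never read
        for i, row in enumerate(weights):
            out[i] += row[j]
            acs += 1
    return out, acs


weights = [[0.5, -1.0, 0.25, 2.0],
           [1.5, 0.75, -0.5, 0.0],
           [-0.25, 1.0, 1.0, -2.0]]
spikes = [1, 0, 0, 1]  # 2 of 4 inputs fired

print(dense_layer(spikes, weights))         # 12 MACs
print(event_driven_layer(spikes, weights))  # same output, 6 ACs
```

The operation count of the event-driven pass is the number of spikes times the fan-out, so halving the firing rate halves the work.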

In Part 3 we measure this directly. The trained SNN fires at a 9.7% average rate, so on average 90.3% of neurons are silent at any given timestep. Their outgoing synaptic weights are never read, no current accumulates downstream, and no energy is spent on them. To compare with the ANN, we count the SNN's Synaptic Operations (SOPs), the spike-triggered accumulates, and convert them to MAC-equivalent units using the roughly 5:1 cost ratio:

$$ \text{SNN effective energy} = \frac{\text{SOPs}}{5} = \frac{649{,}765}{5} \approx 129{,}953 \text{ MAC-equivalents} $$

Compared to the ANN's 268,800 MACs, this is a 51.7% energy reduction: the SNN uses about 48% of the ANN's energy while losing only 0.47 percentage points of accuracy on MNIST. The gains could grow at scale, since wider networks have more neurons that can stay silent, and temporal coding techniques may push the average firing rate below 5%, which would improve efficiency further.
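
The arithmetic behind these figures is short enough to reproduce. In this sketch the SOP count and the 5:1 AC-to-MAC cost ratio are the numbers quoted above; the ANN's 268,800 MACs correspond to one MAC per weight of the 784-256-256-10 MLP with biases excluded, which the code checks.

```python
sops = 649_765          # synaptic operations measured for the SNN (Part 3)
ac_per_mac = 5.0        # assumed: one AC costs about 1/5 of a MAC

# ANN cost: one MAC per weight of a 784-256-256-10 MLP (biases not counted)
ann_macs = 784 * 256 + 256 * 256 + 256 * 10

snn_mac_equiv = sops / ac_per_mac
energy_ratio = snn_mac_equiv / ann_macs
print(f"ANN MACs:            {ann_macs:,}")
print(f"SNN MAC-equivalents: {snn_mac_equiv:,.0f}")
print(f"SNN/ANN energy:      {energy_ratio:.1%}  (reduction {1 - energy_ratio:.1%})")
```

The result depends linearly on the assumed cost ratio: with a 3:1 ratio instead of 5:1, the same SOP count would land above the ANN's MAC budget.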

What Comes Next

Part 3 puts this SNN head-to-head against a parameter-matched MLP on MNIST for 10 epochs: accuracy curves, Synaptic Operation counts vs. MACs, and the full energy comparison.