Deconstructing Spiking Neural Networks (SNNs)

Part 1: Biology, Spikes, and the LIF Model

Introduction

Every neural network you have ever trained—CNN, Transformer, LSTM—shares one fundamental assumption: neurons communicate via continuous-valued activations. A ReLU neuron outputs a floating-point number; a softmax layer produces a probability vector. These real-valued signals flow through dense matrix multiplications on every forward pass, demanding billions of multiply-accumulate (MAC) operations for even modest models.

The brain does none of this. Biological neurons communicate via discrete, all-or-nothing electrical pulses called action potentials (or spikes). This sparse, event-driven communication is why a human brain performs remarkable cognitive feats on roughly 20 watts—while a GPU-accelerated Transformer requires kilowatts for comparable tasks.

Spiking Neural Networks (SNNs) are the attempt to bring this paradigm into machine learning. In this 3-part series, we will completely deconstruct SNNs: from the biological neuron model and the infamous gradient problem, to a full PyTorch implementation, to a benchmark against a standard ANN on MNIST.

The Biological Neuron

A biological neuron integrates electrical charge from thousands of synaptic inputs through its dendrites. This charge gradually builds the neuron's membrane potential $V_m$. If $V_m$ crosses a threshold $V_{th}$, the neuron fires a spike—a stereotyped 1–2 ms voltage pulse—which travels down the axon to stimulate or inhibit downstream neurons.

Crucially, after firing the neuron's potential is rapidly reset and then hyperpolarized (driven below resting potential) for a short refractory period. Information is not encoded in the amplitude of the spike (they are all the same height). It is encoded in timing and rate—when and how often the neuron fires.

The Leaky Integrate-and-Fire (LIF) Model

The simplest and most widely used computational model of the spiking neuron is the Leaky Integrate-and-Fire (LIF) model. It captures the essential dynamics in a single first-order differential equation:

$$ \tau_m \frac{dV(t)}{dt} = -V(t) + RI(t) $$

Here, $V(t)$ is the membrane potential, $\tau_m$ is the membrane time constant (governing how fast the potential leaks back to rest), $I(t)$ is the synaptic input current, and $R$ is the membrane resistance. The dynamics are augmented by two firing rules:

  1. Fire: If $V(t) \geq V_{th}$, emit a spike $S(t) = 1$.
  2. Reset: After firing, immediately reset $V(t) \leftarrow 0$.
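
The leak term can be sanity-checked numerically in isolation: with $I(t) = 0$, the ODE has the closed-form solution $V(t) = V(0)\,e^{-t/\tau_m}$, so after one time constant the potential should decay to $V(0)/e \approx 0.368\,V(0)$. A minimal forward-Euler sketch (the constants are illustrative):

```python
import math

# Forward-Euler integration of tau_m * dV/dt = -V(t) with no input (I = 0).
# After one time constant tau_m, the potential should decay to ~ V0 / e.
tau_m = 10.0      # membrane time constant in ms (illustrative value)
dt = 0.001        # ms; small step for accuracy
V, V0 = 1.0, 1.0
t = 0.0
while t < tau_m:
    V += (dt / tau_m) * (-V)   # dV = (dt/tau_m) * (-V + R*I), with I = 0
    t += dt
print(V, V0 / math.e)          # both ~= 0.368
```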

For implementation on a digital computer, we discretize time into steps of size $\Delta t$. Integrating the leak term exactly over one step yields the discrete membrane decay factor $\beta = \exp(-\Delta t / \tau_m) \in (0,1)$ (a forward-Euler step would give the first-order approximation $1 - \Delta t/\tau_m$). Absorbing $R$ into the input current, we get the update rule used in our PyTorch implementation:

$$ V[t] = \beta \cdot V[t-1] + I[t] $$

$$ S[t] = \Theta\bigl(V[t] - V_{th}\bigr) $$

$$ V[t] = V[t] \cdot \bigl(1 - S[t]\bigr) $$

where $\Theta(\cdot)$ is the Heaviside step function. The final equation implements the hard reset of rule 2: wherever a spike occurred, the membrane potential is forced back to zero. (A common alternative, the soft reset, instead subtracts the threshold, $V[t] \leftarrow V[t] - S[t]\,V_{th}$, so the membrane drops by $V_{th}$ rather than being zeroed.)
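
These three equations map directly onto a few lines of PyTorch. A minimal sketch of a single LIF neuron driven by a constant input current ($\beta$, $V_{th}$, and the input value are all illustrative):

```python
import torch

def lif_step(I_t, V_prev, beta=0.9, v_th=1.0):
    """One discrete LIF update: leak + integrate, fire, hard reset."""
    V = beta * V_prev + I_t          # V[t] = beta * V[t-1] + I[t]
    S = (V >= v_th).float()          # S[t] = Heaviside(V[t] - v_th)
    V = V * (1.0 - S)                # hard reset to 0 where a spike occurred
    return S, V

# Drive one neuron with a constant input current for 10 timesteps
V = torch.zeros(1)
spikes = []
for t in range(10):
    S, V = lif_step(torch.tensor([0.3]), V)
    spikes.append(int(S.item()))
print(spikes)  # [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]
```

With a constant per-step input, the neuron integrates over several steps, fires, resets, and repeats—producing a periodic spike train whose rate grows with the input magnitude.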

Rate Coding vs. Temporal Coding

How does a spiking neuron encode information? There are two primary paradigms:

  1. Rate coding: information is carried by the firing rate—how many spikes a neuron emits within a time window. It is robust and simple to decode, but requires many timesteps.
  2. Temporal coding: information is carried by the precise timing of individual spikes (e.g., time-to-first-spike). It can be far sparser and faster, but is harder to train.

In our PyTorch implementation (Part 2), we use rate coding: the same input pixel values are presented at every timestep $t \in \{1, \ldots, T\}$, and the output class is determined by which output neuron accumulates the most spikes over $T$ steps.
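
The rate-coded readout can be sketched as follows (the layer sizes, random weights, and hyperparameters here are illustrative, not the Part 2 architecture):

```python
import torch

def rate_coded_readout(x, weight, T=25, beta=0.9, v_th=1.0):
    """Present the same static input at every timestep; count output spikes.

    x:      (batch, n_in) static input (e.g. normalized pixel values)
    weight: (n_in, n_out) synaptic weights
    Returns the predicted class (most-spiking neuron) and the spike counts.
    """
    V = torch.zeros(x.shape[0], weight.shape[1])
    counts = torch.zeros_like(V)
    I = x @ weight                   # identical input current at every step
    for _ in range(T):
        V = beta * V + I             # leaky integration
        S = (V >= v_th).float()      # fire
        V = V * (1.0 - S)            # hard reset
        counts += S                  # accumulate spikes per output neuron
    return counts.argmax(dim=1), counts

torch.manual_seed(0)
x = torch.rand(2, 4)                 # 2 samples, 4 input features
w = torch.rand(4, 3)                 # 3 output "classes"
pred, counts = rate_coded_readout(x, w)
```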

The Gradient Problem

This is where SNNs diverge sharply from standard deep learning. The spike emission step $S[t] = \Theta(V[t] - V_{th})$ is a Heaviside step function. Its derivative is a Dirac delta: zero almost everywhere, and undefined at the threshold. Backpropagation requires computing $\frac{\partial S}{\partial V}$ at every neuron—and with the true gradient, the entire chain rule collapses to zero.

SNNs appeared, for decades, to be fundamentally untrainable via gradient descent. The solution, explored in Part 2, is the surrogate gradient trick.
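
The problem is easy to observe directly in autograd. A sketch (threshold and membrane values are illustrative): a hard threshold severs the computation graph entirely, and even a smooth approximation shows why—as the sigmoid sharpens toward the Heaviside, its slope collapses to zero everywhere away from the threshold.

```python
import torch

V = torch.tensor([0.4, 1.3], requires_grad=True)

# A hard threshold severs the autograd graph: no gradient can flow.
S = (V >= 1.0).float()
print(S.requires_grad)           # False

# Approximating the Heaviside with an ever-sharper sigmoid shows why:
k = 100.0                        # steepness; sigmoid -> Heaviside as k -> inf
S_soft = torch.sigmoid(k * (V - 1.0))
S_soft.sum().backward()
print(V.grad)                    # ~0 for both entries: away from the
                                 # threshold, the slope vanishes
```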

Next Steps: Building an SNN in PyTorch

The equations in this post are clean and compact, but translating them to a trainable PyTorch module requires solving the gradient problem. In Part 2, we will implement the LIF update in a custom nn.Module, introduce the fast-sigmoid surrogate gradient via a custom torch.autograd.Function, and assemble a full 2-layer SNN classifier for MNIST.