Every neural network you have ever trained—CNN, Transformer, LSTM—shares one fundamental assumption: neurons communicate via continuous-valued activations. A ReLU neuron outputs a floating-point number; a softmax layer produces a probability vector. These real-valued signals flow through dense matrix multiplications on every forward pass, demanding billions of multiply-accumulate (MAC) operations for even modest models.
The brain does none of this. Biological neurons communicate via discrete, all-or-nothing electrical pulses called action potentials (or spikes). This sparse, event-driven communication is why a human brain performs remarkable cognitive feats on roughly 20 watts, while a GPU-accelerated Transformer requires kilowatts for comparable tasks.
Spiking Neural Networks (SNNs) are the attempt to bring this paradigm into machine learning. In this 3-part series, we will completely deconstruct SNNs: from the biological neuron model and the infamous gradient problem, to a full PyTorch implementation, to a benchmark against a standard ANN on MNIST.
The Biological Neuron
A biological neuron integrates electrical charge from thousands of synaptic inputs through its dendrites. This charge gradually builds the neuron's membrane potential $V_m$. If $V_m$ crosses a threshold $V_{th}$, the neuron fires a spike—a stereotyped 1–2 ms voltage pulse—which travels down the axon to stimulate or inhibit downstream neurons.
Crucially, after firing the neuron's potential is rapidly reset and then hyperpolarized (driven below resting potential) for a short refractory period. Information is not encoded in the amplitude of the spike (they are all the same height). It is encoded in timing and rate—when and how often the neuron fires.
The Leaky Integrate-and-Fire (LIF) Model
The simplest and most widely used computational model of the spiking neuron is the Leaky Integrate-and-Fire (LIF) model. It captures the essential dynamics in a single first-order differential equation:

$$\tau_m \frac{dV(t)}{dt} = -V(t) + R\, I(t)$$
Here, $V(t)$ is the membrane potential, $\tau_m$ is the membrane time constant (governing how fast the potential leaks back to rest), $I(t)$ is the synaptic input current, and $R$ is the membrane resistance. The dynamics are augmented by two firing rules:
- Fire: If $V(t) \geq V_{th}$, emit a spike $S(t) = 1$.
- Reset: After firing, immediately reset $V(t) \leftarrow 0$.
For implementation on a digital computer, we discretize in time with step $\Delta t$ (an exponential Euler step). Defining $\beta = \exp(-\Delta t / \tau_m) \in (0,1)$ as the discrete membrane decay factor, we get the update rule used in our PyTorch implementation:

$$V[t] = \beta\, V[t-1] + (1-\beta)\, R\, I[t]$$
$$S[t] = \Theta(V[t] - V_{th})$$
$$V[t] \leftarrow V[t] - S[t]\, V_{th}$$
where $\Theta(\cdot)$ is the Heaviside step function, and the final equation implements the soft reset: after a spike, the membrane drops by $V_{th}$.
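Before reaching for PyTorch, the discrete update above can be sketched in a few lines of plain Python. All constants here are illustrative assumptions ($\tau_m = 10$ ms, $\Delta t = 1$ ms, $V_{th} = 1$, $R = 1$), not values from a particular library:

```python
import math

# Minimal sketch of the discrete LIF update (assumed constants).
tau_m, dt, v_th, R = 10.0, 1.0, 1.0, 1.0
beta = math.exp(-dt / tau_m)            # discrete membrane decay factor

def lif_step(v, i):
    """One timestep: leak + integrate, threshold, soft reset."""
    v = beta * v + (1 - beta) * R * i   # V[t] = beta V[t-1] + (1-beta) R I[t]
    s = 1.0 if v >= v_th else 0.0       # S[t] = Theta(V[t] - V_th)
    v -= s * v_th                       # soft reset: drop by V_th on a spike
    return v, s

# Drive the neuron with a constant supra-threshold current and count spikes.
v, spikes = 0.0, []
for _ in range(50):
    v, s = lif_step(v, 2.0)
    spikes.append(s)
print(sum(spikes))                      # the neuron fires periodically
```

Note that with this update a constant current $I$ pushes the potential toward the steady-state value $R\,I$, so the neuron only fires if $R\,I$ exceeds $V_{th}$ — the leak term is what makes weak inputs decay away instead of accumulating forever.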
Rate Coding vs. Temporal Coding
How does a spiking neuron encode information? There are two primary paradigms:
- Rate Coding: The information is the average firing rate over a time window $T$. A strongly stimulated neuron fires many times; a weakly stimulated one fires rarely. This is the most common encoding scheme in computational SNNs because it maps naturally onto probabilities.
- Temporal Coding: The information is encoded in the precise timing of individual spikes. This is believed to play a dominant role in biological brains and can convey information with very few spikes, making it extremely efficient, but it is much harder to train computationally.
In our PyTorch implementation (Part 2), we use rate coding: the same input pixel values are presented at every timestep $t \in \{1, \ldots, T\}$, and the output class is determined by which output neuron accumulates the most spikes over $T$ steps.
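As a toy illustration of rate coding, the LIF step can be driven with the same static "pixel" value at every timestep and read out as a total spike count. The constants and the input gain below are assumptions for the sketch, not the values used in Part 2:

```python
import math

# Toy rate-coding readout (illustrative constants, plain Python):
# the same pixel value is injected at every timestep; the output is
# the number of spikes accumulated over T steps.
tau_m, dt, v_th, T = 10.0, 1.0, 1.0, 100
beta = math.exp(-dt / tau_m)
gain = 4.0                                       # assumed pixel -> current scaling

def spike_count(pixel, steps=T):
    v, count = 0.0, 0
    for _ in range(steps):
        v = beta * v + (1 - beta) * gain * pixel  # leak + integrate
        if v >= v_th:                             # threshold crossing
            count += 1
            v -= v_th                             # soft reset
    return count

# A brighter pixel drives a higher firing rate; a very dim one may
# never reach threshold at all (steady-state potential below V_th).
print(spike_count(0.2), spike_count(0.5), spike_count(0.9))
```

This is the sense in which rate coding "maps naturally onto probabilities": over a long enough window, the normalized spike count behaves like a bounded, graded activation.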
The Gradient Problem
This is where SNNs diverge sharply from standard deep learning. The spike emission step $S[t] = \Theta(V[t] - V_{th})$ is a Heaviside step function. Its derivative is a Dirac delta: zero almost everywhere, and undefined at the threshold. Backpropagation requires computing $\frac{\partial S}{\partial V}$ at every neuron—and with the true gradient, the entire chain rule collapses to zero.
SNNs appeared, for decades, to be fundamentally untrainable via gradient descent. The solution, explored in Part 2, is the surrogate gradient trick.
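The problem is easy to see numerically. In this small finite-difference sketch (sample potentials chosen arbitrarily for illustration), perturbing the membrane potential anywhere off the threshold leaves the spike output unchanged, so the measured derivative of the Heaviside step is identically zero:

```python
# Finite-difference view of dS/dV for S = Theta(V - V_th): a small
# perturbation of V never changes the binary spike output unless it
# straddles the threshold exactly, so the gradient is 0 everywhere else.
v_th = 1.0

def heaviside(v):
    return 1.0 if v >= v_th else 0.0

eps = 1e-4
for v in (0.2, 0.7, 0.99, 1.3):        # sample potentials off the threshold
    grad = (heaviside(v + eps) - heaviside(v - eps)) / (2 * eps)
    print(f"V={v}: dS/dV ~ {grad}")    # prints 0.0 for every sample
```

Any gradient that must pass through this step function is multiplied by zero, which is exactly why the chain rule collapses during backpropagation.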
Next Steps: Building an SNN in PyTorch
The equations in this post are clean and compact, but translating them to a trainable PyTorch module requires solving the gradient problem. In Part 2, we will implement the LIF update in a custom `nn.Module`, introduce the fast-sigmoid surrogate gradient via a custom `torch.autograd.Function`, and assemble a full 2-layer SNN classifier for MNIST.