
Deconstructing Spiking Neural Networks (SNNs)

Part 1: Biology, Spikes, and the LIF Model

Introduction

Every neural network you have ever trained—CNN, Transformer, LSTM—shares one fundamental assumption: neurons communicate via continuous-valued activations. A ReLU neuron outputs a floating-point number; a softmax layer produces a probability vector. These real-valued signals flow through dense matrix multiplications on every forward pass, demanding billions of multiply-accumulate (MAC) operations for even modest models.

The brain does none of this. Biological neurons communicate via discrete, all-or-nothing electrical pulses called action potentials (or spikes). A neuron fires only when its internal charge crosses a threshold; the rest of the time it is silent. This sparse, event-driven communication is a major reason the human brain runs on roughly 20 watts, while a single datacenter GPU running a Transformer typically draws several hundred watts and large training clusters draw far more.

Spiking Neural Networks (SNNs) try to bring this paradigm into machine learning. This 3-part series deconstructs SNNs: the biological neuron model and the gradient problem (here), a full PyTorch implementation (Part 2), and a head-to-head benchmark against a standard ANN on MNIST (Part 3).

The Biological Neuron

A biological neuron integrates electrical charge from thousands of synaptic inputs through its dendrites. This charge gradually builds up the neuron's membrane potential $V_m$. If $V_m$ crosses a threshold $V_{th}$, the neuron fires a spike, a stereotyped voltage pulse lasting roughly 1–2 ms, which travels down the axon to excite or inhibit downstream neurons.

After firing, the membrane potential is rapidly reset and then briefly hyperpolarized (driven below its resting value) during a short refractory period. During this window the neuron is effectively silent, recovering toward its resting potential before it can fire again. This refractory mechanism prevents runaway firing and enforces a natural sparsity: even a heavily stimulated neuron has an upper bound on its firing rate.

Spikes are all roughly the same height, so information is not carried by their amplitude. It is carried by timing and rate: when and how often a neuron fires. This is a radically different communication protocol from the continuous-valued activations of standard neural networks, and it is a key reason biological neural circuits are so energy-efficient. A neuron that does not fire sends no signal: its downstream synapses are never activated, and in a digital implementation the corresponding weights are never read.
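To make the "weights are never read" point concrete, here is a minimal sketch in plain Python. It assumes a layer whose input current is a weight matrix times a binary spike vector; the weights, the spike vector, and the helper names `dense_current` and `event_driven_current` are hypothetical and chosen only for illustration. Because each spike is either 0 or 1, the matrix-vector product reduces to summing the weight columns of the inputs that actually fired, so only additions remain and silent inputs cost nothing.

```python
# Sketch: with binary spikes, a layer's input current W @ s reduces to
# summing the weight columns of the inputs that actually spiked.
# Weights and spikes below are made-up illustrative values.

def dense_current(weights, inputs):
    """ANN-style matrix-vector product: one multiply-accumulate per weight."""
    macs = 0
    out = []
    for row in weights:
        acc = 0.0
        for w, x in zip(row, inputs):
            acc += w * x
            macs += 1
        out.append(acc)
    return out, macs

def event_driven_current(weights, spikes):
    """Event-driven version: only read weights for inputs that spiked (1)."""
    active = [j for j, s in enumerate(spikes) if s]  # indices of spiking inputs
    adds = 0
    out = []
    for row in weights:
        acc = 0.0
        for j in active:
            acc += row[j]  # no multiply needed: the spike value is 1
            adds += 1
        out.append(acc)
    return out, adds

weights = [[0.5, -0.2, 0.1, 0.3],
           [0.0, 0.4, -0.6, 0.2]]
spikes = [1, 0, 0, 1]

dense, macs = dense_current(weights, spikes)
sparse, adds = event_driven_current(weights, spikes)
print(dense, macs)
print(sparse, adds)
```

Both functions produce the same currents, but the event-driven one touches only half the weights here; the savings grow as spikes get sparser. Operation counts in Python are, of course, only a proxy for energy on real neuromorphic hardware.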

The Leaky Integrate-and-Fire (LIF) Model

The Leaky Integrate-and-Fire (LIF) model is one of the simplest and most widely used computational spiking neurons. It reduces the dynamics to a single first-order ODE:

$$ \tau_m \frac{dV(t)}{dt} = -V(t) + RI(t) $$

$V(t)$ is the membrane potential (measured relative to rest, so the resting potential is 0), $\tau_m$ is the membrane time constant (how fast the potential leaks back to rest), $I(t)$ is the synaptic input current, and $R$ is the membrane resistance. Two firing rules complete the model:

  1. Fire: If $V(t) \geq V_{th}$, emit a spike $S(t) = 1$.
  2. Reset: After firing, immediately reset $V(t) \leftarrow 0$.
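The ODE and the two firing rules can be simulated directly. The sketch below is a minimal illustration, not the series' implementation: it integrates $\tau_m \, dV/dt = -V + RI$ with small forward-Euler steps under a constant input current and applies the fire and reset rules. The function name `simulate_lif` and all parameter values ($\tau_m = 10$ ms, $R = 1$, $V_{th} = 1$, $\Delta t = 0.1$ ms) are assumptions chosen for illustration.

```python
# Sketch: simulate the continuous LIF ODE  tau_m dV/dt = -V + R*I
# with small forward-Euler steps, applying the fire and reset rules.
# simulate_lif and its parameter values are illustrative assumptions.

def simulate_lif(I, tau_m=10.0, R=1.0, v_th=1.0, dt=0.1, t_end=100.0):
    """Return spike times (ms) for a constant input current I."""
    v = 0.0                      # membrane potential, starting at rest (0)
    spike_times = []
    n_steps = int(round(t_end / dt))
    for n in range(n_steps):
        # Euler step of tau_m dV/dt = -V + R*I
        v += (dt / tau_m) * (-v + R * I)
        if v >= v_th:            # rule 1: fire
            spike_times.append((n + 1) * dt)
            v = 0.0              # rule 2: reset to 0
    return spike_times

weak = simulate_lif(I=0.9)    # R*I < V_th: V settles below threshold, no spikes
strong = simulate_lif(I=1.5)  # R*I > V_th: regular firing
print(len(weak), len(strong))
print(strong[:3])
```

With constant input, $V$ relaxes toward $RI$, so the neuron fires only if $RI > V_{th}$. In that case the analytic time to the first spike from rest is $\tau_m \ln\bigl(RI / (RI - V_{th})\bigr)$, about 11 ms for these values, and the Euler simulation lands close to it.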

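To run this on a digital computer, we discretize time into steps of length $\Delta t$. Define $\beta = \exp(-\Delta t / \tau_m) \in (0,1)$ as the discrete membrane decay factor: it is the exact decay over one step when the input is held constant within that step (a forward Euler step gives the first-order approximation $\beta \approx 1 - \Delta t / \tau_m$). Absorbing $R$ and the remaining constant scaling into the input term $I[t]$, the update rule becomes: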

$$ V[t] = \beta \cdot V[t-1] + I[t] $$
$$ S[t] = \Theta\bigl(V[t] - V_{th}\bigr) $$
$$ V[t] \leftarrow V[t] \cdot \bigl(1 - S[t]\bigr) $$

where $\Theta(\cdot)$ is the Heaviside step function (with $\Theta(0) = 1$, matching the $\geq$ in the firing rule), and the final equation implements a hard reset: after a spike, the membrane potential is set back to 0, as in the reset rule above. (A common alternative, the soft reset, subtracts the threshold instead: $V[t] \leftarrow V[t] - V_{th} \, S[t]$.) The decay factor $\beta$ controls the neuron's memory. A value close to 1 means the neuron retains charge across many timesteps, integrating information over a long temporal window. A value close to 0 means the neuron is nearly memoryless: it responds almost only to the current input. In our PyTorch implementation (Part 2), $\beta$ is a learnable parameter per neuron, constrained to $(0,1)$ via a sigmoid, so the network discovers the right time constant for each neuron during training.
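A plain-Python sketch of the three discrete update equations, plus the sigmoid constraint on $\beta$, may help. This is not the Part 2 code: the helper names (`lif_step`, `run`, `effective_tau`), the raw parameter values, and the input sequence are all hypothetical. It runs the same input through a long-memory neuron ($\beta \approx 0.90$) and a short-memory one ($\beta \approx 0.12$) to show how $\beta$ decides whether weak inputs can accumulate to threshold.

```python
import math

# Sketch of the discrete LIF update (hard reset to 0), with beta obtained
# from an unconstrained parameter through a sigmoid. Helper names, raw
# parameter values, and the input sequence are illustrative assumptions.

def sigmoid(x):
    """Maps any real number into (0, 1); saturates in floating point for huge |x|."""
    return 1.0 / (1.0 + math.exp(-x))

def effective_tau(beta, dt=1.0):
    """Invert beta = exp(-dt / tau_m) to get tau_m (requires 0 < beta < 1)."""
    return -dt / math.log(beta)

def lif_step(v_prev, i_t, beta, v_th=1.0):
    """One timestep: leak + integrate, threshold, reset. Returns (v, s)."""
    v = beta * v_prev + i_t      # V[t] = beta * V[t-1] + I[t]
    s = 1 if v >= v_th else 0    # S[t] = Theta(V[t] - V_th), Theta(0) = 1
    v = v * (1 - s)              # V[t] <- V[t] * (1 - S[t])
    return v, s

def run(inputs, beta):
    """Feed an input sequence through one neuron; return its spike train."""
    v, spikes = 0.0, []
    for i_t in inputs:
        v, s = lif_step(v, i_t, beta)
        spikes.append(s)
    return spikes

inputs = [0.3, 0.3, 0.3, 0.3, 0.0, 0.0, 0.6, 0.6]

beta_long = sigmoid(2.2)    # about 0.90: charge persists across steps
beta_short = sigmoid(-2.0)  # about 0.12: charge leaks away almost immediately
print(beta_long, effective_tau(beta_long), run(inputs, beta_long))
print(beta_short, effective_tau(beta_short), run(inputs, beta_short))
```

With the long time constant, four inputs of 0.3 add up past the threshold; with the short one, the membrane forgets each input before the next arrives and never fires. Training $\beta$ through the sigmoid keeps it inside $(0,1)$ without any explicit clipping.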

Rate Coding vs. Temporal Coding

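Two paradigms for encoding information in spikes:

  1. Rate coding: information is carried by how many spikes a neuron emits over a time window, i.e., its firing rate.
  2. Temporal coding: information is carried by precise spike timing, for example how soon after stimulus onset a neuron fires its first spike.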

In our implementation (Part 2), we use rate coding: the same input pixel values are presented at every timestep $t \in \{1, \ldots, T\}$, and the output class is whichever neuron accumulates the most spikes over $T$ steps.
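The sketch below illustrates rate-coded readout with the discrete LIF update: a constant current is presented to each output neuron for $T$ steps, spikes are counted, and the predicted class is the neuron with the highest count. It also records each neuron's first-spike time, which is what a time-to-first-spike (temporal) decoder would read instead. The per-class currents are hypothetical stand-ins for what a trained network's last layer might deliver for one image, and `run_output_layer`, $T = 25$, and $\beta = 0.9$ are illustrative assumptions.

```python
# Sketch of rate-coded readout: present the same input at every timestep,
# count each output neuron's spikes over T steps, and pick the argmax.
# The currents, T, beta, and run_output_layer are illustrative assumptions.

def run_output_layer(currents, T=25, beta=0.9, v_th=1.0):
    """Return spike counts and first-spike times for constant input currents."""
    counts = [0] * len(currents)
    first_spike = [None] * len(currents)
    v = [0.0] * len(currents)
    for t in range(1, T + 1):
        for k, i_k in enumerate(currents):
            v[k] = beta * v[k] + i_k      # leak + integrate
            if v[k] >= v_th:              # fire
                counts[k] += 1
                if first_spike[k] is None:
                    first_spike[k] = t    # latency: timestep of first spike
                v[k] = 0.0                # hard reset
    return counts, first_spike

currents = [0.05, 0.40, 0.15, 0.70]   # one value per class (hypothetical)
counts, first_spike = run_output_layer(currents)
predicted = max(range(len(counts)), key=lambda k: counts[k])
print(counts, first_spike, predicted)
```

Neuron 0 never fires because its steady-state potential $I/(1-\beta) = 0.5$ stays below threshold. The neuron with the largest current both fires most often and fires first, so in this simple constant-input case the rate and latency decoders agree on the class.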

The Gradient Problem

Here is where SNNs diverge sharply from standard deep learning. The spike emission $S[t] = \Theta(V[t] - V_{th})$ is a Heaviside step function. Its derivative is a Dirac delta: zero almost everywhere, undefined at the threshold. Backpropagation requires $\frac{\partial S}{\partial V}$ at every neuron---and with the true gradient, the entire chain rule collapses to zero.

This is not a minor inconvenience. In a multi-layer SNN, every layer contains a spiking nonlinearity. The gradient must pass through all of them during backpropagation, and each one multiplies the signal by zero. The result: no weight in the network receives any useful learning signal. For a long time, this made SNNs appear impractical to train with standard gradient descent.
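A small numerical check makes the collapse visible. This sketch, in plain Python with illustrative membrane values and hypothetical helper names, estimates $\partial S / \partial V$ of the Heaviside spike function with central finite differences, then multiplies the local derivatives along a deliberately simplified single path through three spiking layers. Real backpropagation through time sums over many paths and timesteps and also multiplies by weights, but every path contains factors like these.

```python
# Sketch: the true derivative of the Heaviside spike function is zero
# everywhere except at the threshold, so backprop through stacked spiking
# layers multiplies the gradient by zero. Values are illustrative.

V_TH = 1.0

def heaviside(v, v_th=V_TH):
    """Spike function: 1 at or above threshold, else 0."""
    return 1.0 if v >= v_th else 0.0

def finite_diff(f, v, h=1e-6):
    """Central-difference estimate of df/dv."""
    return (f(v + h) - f(v - h)) / (2 * h)

# Membrane potentials that are not exactly at the threshold
for v in (0.2, 0.9, 1.1, 3.0):
    print(v, finite_diff(heaviside, v))

# Simplified chain rule through three spiking layers: the gradient picks up
# one dS/dV factor per layer, so a single zero wipes out the whole signal.
grad = 1.0  # upstream gradient dL/dS at the output
for v_layer in (0.7, 1.3, 0.95):
    grad *= finite_diff(heaviside, v_layer)
print(grad)
```

Only an input exactly at $V_{th}$ produces a nonzero estimate, and that estimate grows without bound as `h` shrinks, which is the finite-difference shadow of the Dirac delta.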

The solution is the surrogate gradient trick: keep the forward pass exact (hard Heaviside threshold), but substitute a smooth, differentiable approximation during the backward pass only. The surrogate we use is the derivative of the fast sigmoid, $\frac{1}{(1 + k|V - V_{th}|)^2}$, which peaks at the threshold and smoothly decays to zero away from it. The sharpness factor $k$ controls how closely the surrogate approximates the true step---we use $k = 10$. This is implemented in Part 2 via a custom torch.autograd.Function.
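Here is the surrogate gradient trick worked by hand for a single neuron, as a minimal sketch. Part 2 packages the same idea in a custom `torch.autograd.Function`; this version avoids PyTorch so the arithmetic stays visible. The weight, input, target, squared-error loss, and helper names are illustrative assumptions; $V_{th} = 1$ and $k = 10$ follow the text. The forward pass keeps the exact threshold, and only the backward pass swaps $\partial S / \partial V$ for the fast-sigmoid derivative.

```python
# Sketch of the surrogate-gradient idea for one neuron, written by hand.
# Forward uses the exact Heaviside; backward replaces dS/dV with the
# fast-sigmoid derivative 1 / (1 + k|V - V_th|)^2.
# Weight, input, target, and loss are illustrative assumptions.

V_TH = 1.0
K = 10.0  # sharpness factor k from the text

def spike_forward(v, v_th=V_TH):
    """Exact forward pass: hard threshold."""
    return 1.0 if v >= v_th else 0.0

def spike_surrogate_grad(v, v_th=V_TH, k=K):
    """Backward-pass stand-in for dS/dV: 1 / (1 + k|V - V_th|)^2."""
    return 1.0 / (1.0 + k * abs(v - v_th)) ** 2

# Single neuron: V = w * x, S = Heaviside(V - V_th), loss = (S - target)^2
w, x, target = 0.8, 1.0, 1.0
v = w * x
s = spike_forward(v)              # 0.0: V = 0.8 is below threshold
dloss_ds = 2.0 * (s - target)     # -2.0
true_grad_w = dloss_ds * 0.0 * x  # true dS/dV is 0 away from threshold
surrogate_grad_w = dloss_ds * spike_surrogate_grad(v) * x
print(true_grad_w, surrogate_grad_w)

# The surrogate peaks (value 1) at the threshold and decays away from it
for v_probe in (0.5, 0.9, 1.0, 1.1, 1.5):
    print(v_probe, round(spike_surrogate_grad(v_probe), 4))
```

The true gradient is zero, so plain gradient descent would leave `w` unchanged forever. The surrogate gradient is $-2/9$, so a descent step $w \leftarrow w - \eta \cdot \text{grad}$ increases `w` and pushes the membrane potential toward threshold, which is exactly the learning signal the neuron needs to start firing.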

What Comes Next

Part 2 takes these equations and turns them into a trainable PyTorch module: a custom torch.autograd.Function for the fast-sigmoid surrogate gradient, a LIFNeuron layer with learnable decay, and a 2-layer SNN classifier for MNIST.