In Part 1, we explored the biological plausibility of Predictive Coding Networks (PCNs) and how they
flip standard Multi-Layer Perceptrons (MLPs) on their head. We showed that PCNs can replace the
global, non-biological loss.backward() of standard Artificial Neural Networks with
iterative, local energy minimization.
Today, we will turn that theory into raw PyTorch code. Our goal is to build a network that learns a simple task without ever calling loss.backward() to backpropagate through the whole model!
The PCN Layer
A standard PyTorch nn.Linear layer simply maps inputs to outputs. A PCN layer, however,
contains dynamic states that require updating during an explicit inference phase. Each layer
holds a latent "belief" and calculates its own local prediction error against the layer below it.
import torch
import torch.nn as nn

class PCNLayer(nn.Module):
    def __init__(self, in_features, out_features, activation=torch.tanh):
        super().__init__()
        self.out_features = out_features
        self.activation = activation
        # Generative weights predicting the layer below
        self.W = nn.Linear(out_features, in_features, bias=True)
        # Internal states for the PCN
        self.latent_state = None
        self.prediction_error = None

    def forward(self):
        # Top-down prediction of the layer below: W f(v) + b
        return self.W(self.activation(self.latent_state))

    def compute_prediction_error(self, target_state):
        prediction = self.forward()
        self.prediction_error = target_state - prediction
        return self.prediction_error
Notice that we recommend torch.tanh. In standard neural nets, ReLU is the default, but in iterative PCNs ReLU's zero gradient for negative inputs can cause latent beliefs to
"die" during the inference phase, halting learning.
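A two-line experiment illustrates the mechanism (the target value 1.0 and the latent value -1.5 are purely illustrative): when a latent belief sits in ReLU's negative region, the energy gradient with respect to it is exactly zero, so inference can never move it, while tanh always passes some gradient through.

```python
import torch

# A negative latent belief that should be pushed toward predicting 1.0
v = torch.tensor([-1.5], requires_grad=True)
energy = 0.5 * (torch.relu(v) - 1.0).pow(2).sum()
g_relu = torch.autograd.grad(energy, v)[0]  # relu'(-1.5) = 0, so the belief is stuck

v2 = torch.tensor([-1.5], requires_grad=True)
energy2 = 0.5 * (torch.tanh(v2) - 1.0).pow(2).sum()
g_tanh = torch.autograd.grad(energy2, v2)[0]  # nonzero: inference can still move v2

print(g_relu.item(), g_tanh.item())  # 0.0 vs. a nonzero gradient
```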
Phase 1: The Inference Walk
Unlike feed-forward networks, PCNs don't calculate their output in one step. They enter an inference phase in which the latent states of all layers iteratively undergo gradient descent to minimize the total prediction error. The synaptic weights remain entirely frozen during this phase.
def inference_step(self, inference_lr):
    total_energy = 0.0
    current_target = self.input_data  # Fixed sensory input

    # 1. Compute prediction errors bottom-up
    for i in range(self.num_layers):
        error = self.layers[i].compute_prediction_error(current_target)
        # Energy = 1/2 * ||error||^2
        layer_energy = 0.5 * torch.sum(error ** 2)
        total_energy += layer_energy
        current_target = self.layers[i].latent_state

    # 2. Extract latent states that require gradients
    latent_states = [layer.latent_state for layer in self.layers
                     if layer.latent_state.requires_grad]
    if len(latent_states) == 0:
        return total_energy.item()

    # 3. Gradient w.r.t. latent states ONLY (weights are frozen)
    gradients = torch.autograd.grad(total_energy, latent_states)

    # 4. Update the latent states (minimize energy)
    with torch.no_grad():
        grad_idx = 0
        for layer in self.layers:
            if layer.latent_state.requires_grad:
                layer.latent_state.sub_(inference_lr * gradients[grad_idx])
                grad_idx += 1
    return total_energy.item()
We use PyTorch's autograd.grad explicitly here to descend the energy landscape with respect
to the latent nodes. No optimizer.step() is used on the weights.
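One detail inference_step assumes but does not show: before the walk begins, the sensory input must be clamped to self.input_data, and every latent state must be initialized as a leaf tensor with requires_grad=True so autograd.grad can reach it. A minimal sketch of that setup (the reset_states helper and the random initialization are our own assumptions; initializing from a feed-forward guess is another common choice):

```python
import torch
from types import SimpleNamespace

def reset_states(net, x):
    """Clamp the sensory input and give each layer a fresh latent belief."""
    net.input_data = x  # fixed for the whole inference phase
    for layer in net.layers:
        # Leaf tensors with requires_grad=True, so autograd.grad can update them
        layer.latent_state = torch.randn(
            x.size(0), layer.out_features, requires_grad=True)

# Stand-in network with two layers, just to show the resulting shapes
net = SimpleNamespace(layers=[SimpleNamespace(out_features=8),
                              SimpleNamespace(out_features=3)])
reset_states(net, torch.randn(5, 4))
print(net.layers[0].latent_state.shape)  # torch.Size([5, 8])
```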
Phase 2: The Biological Learning Rule
Once the inference phase has run for sufficient iterations (e.g., $T=50$) and the latent states have settled into a low-energy configuration, we execute the learning phase.
This is the biological magic. The global chain rule is gone: each weight update uses only the local prediction error at the post-synaptic layer ($\epsilon_{l-1}$) and the pre-synaptic activation ($f(v_l)$).
def update_weights(self, learning_rate):
    with torch.no_grad():
        for layer in self.layers:
            error = layer.prediction_error
            pre_synaptic = layer.activation(layer.latent_state)
            # Local Hebbian update rule:
            # delta W = error^T @ pre_synaptic, averaged over the batch
            delta_W = torch.matmul(error.t(), pre_synaptic) / error.size(0)
            delta_bias = torch.mean(error, dim=0)
            # Update weights locally
            layer.W.weight.add_(learning_rate * delta_W)
            if layer.W.bias is not None:
                layer.W.bias.add_(learning_rate * delta_bias)
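Putting the two phases together, one full training step is T inference steps followed by a single local weight update. The sketch below assembles everything into a self-contained runnable example; the PCNetwork container class, the layer sizes, the toy data, and both learning rates are our own illustrative choices, not prescriptions from the article:

```python
import torch
import torch.nn as nn

class PCNLayer(nn.Module):  # as defined earlier in the article
    def __init__(self, in_features, out_features, activation=torch.tanh):
        super().__init__()
        self.out_features = out_features
        self.activation = activation
        self.W = nn.Linear(out_features, in_features, bias=True)
        self.latent_state = None
        self.prediction_error = None

    def forward(self):
        return self.W(self.activation(self.latent_state))

    def compute_prediction_error(self, target_state):
        self.prediction_error = target_state - self.forward()
        return self.prediction_error

class PCNetwork(nn.Module):  # hypothetical container for the article's methods
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = nn.ModuleList(
            PCNLayer(layer_sizes[i], layer_sizes[i + 1])
            for i in range(len(layer_sizes) - 1))
        self.num_layers = len(self.layers)
        self.input_data = None

    def reset_states(self, x):
        self.input_data = x  # clamp the sensory input
        for layer in self.layers:
            layer.latent_state = torch.randn(
                x.size(0), layer.out_features, requires_grad=True)

    def inference_step(self, inference_lr):
        total_energy = 0.0
        current_target = self.input_data
        for i in range(self.num_layers):
            error = self.layers[i].compute_prediction_error(current_target)
            total_energy = total_energy + 0.5 * torch.sum(error ** 2)
            current_target = self.layers[i].latent_state
        latent_states = [l.latent_state for l in self.layers]
        grads = torch.autograd.grad(total_energy, latent_states)
        with torch.no_grad():
            for layer, g in zip(self.layers, grads):
                layer.latent_state.sub_(inference_lr * g)
        return total_energy.item()

    def update_weights(self, learning_rate):
        with torch.no_grad():
            for layer in self.layers:
                error = layer.prediction_error
                pre = layer.activation(layer.latent_state)
                layer.W.weight.add_(
                    learning_rate * torch.matmul(error.t(), pre) / error.size(0))
                layer.W.bias.add_(learning_rate * torch.mean(error, dim=0))

# One training step: settle the beliefs for T = 50 steps, then update weights once
torch.manual_seed(0)
net = PCNetwork([4, 8, 3])
x = torch.randn(16, 4)  # toy sensory batch

net.reset_states(x)
first = net.inference_step(inference_lr=0.1)
for _ in range(49):
    settled = net.inference_step(inference_lr=0.1)
net.update_weights(learning_rate=0.01)
print(first, "->", settled)  # energy drops as the beliefs settle
```

In a real training loop, reset_states and the inner inference loop would run once per batch, with update_weights applied after each settling phase.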
Next Up: The Benchmark Showdown
We have successfully built a Predictive Coding Network that learns via local energy minimization. But how does this elegant biological algorithm stack up against raw, globally-optimized Backpropagation?
In Part 3, we will put this PyTorch PCN implementation to the test against an identically sized standard MLP on nonlinear regression and MNIST image classification, analyzing its stability, generative properties, and hardware prospects.
Ready to see the results? Stay tuned for the final data drop!