Part 1 covered the theory: PCNs replace the
global loss.backward() with iterative, local energy minimization. Here we turn that
into working PyTorch code -- a network that learns without backpropagating through the
full graph.
The PCN Layer
A standard nn.Linear maps inputs to outputs in one shot. A PCN layer is different: it
holds a dynamic latent state that gets updated during an inference phase, and it computes its own
local prediction error against the layer below.
import torch
import torch.nn as nn

class PCNLayer(nn.Module):
    def __init__(self, in_features, out_features, activation=torch.tanh):
        super().__init__()
        self.out_features = out_features
        self.activation = activation
        # Generative weights predicting the layer below
        self.W = nn.Linear(out_features, in_features, bias=True)
        # Internal states for the PCN (set during the inference phase)
        self.latent_state = None
        self.prediction_error = None

    def forward(self):
        # Top-down prediction of the layer below: W f(v) + b
        return self.W(self.activation(self.latent_state))

    def compute_prediction_error(self, target_state):
        prediction = self.forward()
        self.prediction_error = target_state - prediction
        return self.prediction_error
We use torch.tanh instead of ReLU. In iterative PCNs, ReLU's zero gradient
for negative inputs causes latent states to die during inference, stalling learning entirely. During
MNIST classification, ReLU flatlined accuracy at exactly 10% -- random guessing -- because negative
latent states received zero gradient and could never recover during the inference walk. Tanh avoids
this by providing a smooth, nonzero gradient across the entire real line.
Weight initialization matters too. We apply Xavier/Glorot uniform initialization to the generative weight matrices and zero-initialize the biases. The top-down predictions must be reasonable from the start for the inference phase to converge; poor initialization can leave the energy landscape too flat or too chaotic for the latent states to settle.
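As a minimal sketch, this initialization drops straight into the PCNLayer constructor above (placing it inside __init__ is our choice; the calls are standard torch.nn.init functions):

    # Inside PCNLayer.__init__, after creating self.W:
    nn.init.xavier_uniform_(self.W.weight)  # Xavier/Glorot uniform on the generative weights
    nn.init.zeros_(self.W.bias)             # zero-initialized biases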
Phase 1: The Inference Walk
PCNs do not compute output in one forward pass. Instead, all latent states undergo gradient descent to minimize total prediction error. Weights stay frozen throughout. This is the fundamental departure from standard neural networks: where a feed-forward net computes its output in a single sweep, a PCN enters an iterative relaxation loop that may run for 20 to 50 steps before producing a stable internal representation.
def inference_step(self, inference_lr):
    total_energy = 0.0
    current_target = self.input_data  # Fixed sensory input

    # 1. Compute prediction errors bottom-up
    for i in range(self.num_layers):
        error = self.layers[i].compute_prediction_error(current_target)
        # Energy = 1/2 * ||error||^2
        layer_energy = 0.5 * torch.sum(error ** 2)
        total_energy += layer_energy
        current_target = self.layers[i].latent_state

    # 2. Extract latent states that require gradients
    latent_states = [layer.latent_state for layer in self.layers
                     if layer.latent_state.requires_grad]
    if len(latent_states) == 0:
        return total_energy.item()

    # 3. Gradient w.r.t. latent states ONLY (weights stay frozen)
    gradients = torch.autograd.grad(total_energy, latent_states,
                                    create_graph=False)

    # 4. Update the latent states (minimize energy)
    with torch.no_grad():
        grad_idx = 0
        for layer in self.layers:
            if layer.latent_state.requires_grad:
                layer.latent_state.sub_(inference_lr * gradients[grad_idx])
                grad_idx += 1

    return total_energy.item()
We deliberately use autograd.grad instead of the standard loss.backward().
The difference is critical: backward() accumulates gradients on every leaf tensor in the
graph, including the weights. autograd.grad lets us target only the latent states, keeping
weights completely frozen. We also pass create_graph=False because we do not need
second-order gradients -- the inference walk is pure first-order gradient descent on the energy
surface, and skipping the higher-order graph saves significant memory.
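inference_step reads self.layers, self.num_layers, and self.input_data, which live on the enclosing network object. Below is a minimal sketch of that container, assuming the PCNLayer above; the class name PCNetwork, the layer_sizes argument, and the reset_states helper are illustrative rather than fixed by the original code, and zero-initializing the latent states is just one simple way to start the inference walk:

    class PCNetwork(nn.Module):
        def __init__(self, layer_sizes):
            super().__init__()
            # layer_sizes = [input_dim, hidden_1, ..., output_dim]
            # Layer i generates a prediction of the activity one level below it.
            self.layers = nn.ModuleList([
                PCNLayer(in_f, out_f)
                for in_f, out_f in zip(layer_sizes[:-1], layer_sizes[1:])
            ])
            self.num_layers = len(self.layers)
            self.input_data = None

        def reset_states(self, x):
            # Clamp the sensory input and give every latent state a fresh,
            # trainable starting point for the inference walk.
            self.input_data = x
            for layer in self.layers:
                layer.latent_state = torch.zeros(
                    x.size(0), layer.out_features, requires_grad=True)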
Phase 2: The Hebbian Weight Update
After the inference phase runs (e.g., $T=50$ iterations) and the latent states settle, we update the weights. The chain rule is gone. Each weight update uses only the post-synaptic prediction error ($\epsilon_{l-1}$) and the pre-synaptic activation ($f(v_l)$): $\Delta W_l \propto \epsilon_{l-1}\, f(v_l)^\top$, averaged over the batch:
def update_weights(self, learning_rate):
    with torch.no_grad():
        for layer in self.layers:
            error = layer.prediction_error                        # epsilon_{l-1}: (batch, in_features)
            pre_synaptic = layer.activation(layer.latent_state)   # f(v_l): (batch, out_features)

            # Local Hebbian update rule, averaged over the batch
            # delta W = error^T @ pre_synaptic
            delta_W = torch.matmul(error.t(), pre_synaptic) / error.size(0)
            delta_bias = torch.mean(error, dim=0)

            # Adding delta_W descends the energy, since dE/dW = -error f(v)^T
            layer.W.weight.add_(learning_rate * delta_W)
            if layer.W.bias is not None:
                layer.W.bias.add_(learning_rate * delta_bias)
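With both methods on the network, one full training step is simply: clamp the input, relax for a fixed number of inference steps, then apply a single weight update. A sketch, reusing the hypothetical reset_states helper from above (train_step, the step count, and the learning rates are illustrative):

    def train_step(self, x, n_inference_steps=50, inference_lr=0.1, weight_lr=0.005):
        # Phase 0: clamp the input and re-initialize all latent states
        self.reset_states(x)

        # Phase 1: inference walk -- latent states move, weights stay frozen
        for _ in range(n_inference_steps):
            energy = self.inference_step(inference_lr)

        # Phase 2: one local Hebbian update at the settled states
        self.update_weights(weight_lr)
        return energy  # final energy, useful for monitoring convergence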
Supervised Learning: Label Clamping
To use a PCN for classification, we clamp the top layer's latent state to the one-hot encoded label
and freeze it (requires_grad = False). During the inference walk, the label information
propagates downward through the generative weights, shaping the internal representations of every
hidden layer. This is the opposite of a feed-forward net, where label information only enters at the
loss function after the forward pass is complete.
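Concretely, the clamp is one extra line before the inference loop. A sketch of the supervised variant, again building on the hypothetical train_step and reset_states scaffolding above (y_onehot is assumed to be a float one-hot tensor of shape (batch, num_classes)):

    def train_step_supervised(self, x, y_onehot, n_inference_steps=50,
                              inference_lr=0.1, weight_lr=0.005):
        self.reset_states(x)

        # Clamp the top layer to the one-hot label and freeze it:
        # the inference walk reads from it but never moves it.
        self.layers[-1].latent_state = y_onehot.clone().requires_grad_(False)

        for _ in range(n_inference_steps):
            self.inference_step(inference_lr)
        self.update_weights(weight_lr)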
At prediction time, we leave the top layer unclamped and let the entire network relax freely. The top layer's settled state becomes the network's "belief" about the input's class. We read it out with a softmax to get class probabilities.
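A sketch of that readout, using the same illustrative scaffolding (predict and its defaults are ours):

    def predict(self, x, n_inference_steps=50, inference_lr=0.1):
        # Clamp only the sensory input; the top layer relaxes freely
        self.reset_states(x)
        for _ in range(n_inference_steps):
            self.inference_step(inference_lr)

        # The settled top-layer state is the network's belief about the class
        logits = self.layers[-1].latent_state.detach()
        return torch.softmax(logits, dim=1)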
Next: Benchmarks
The PCN learns via local energy minimization. Part 3 benchmarks it against an identical MLP on nonlinear regression and MNIST classification.