Deconstructing Predictive Coding Networks

Part 2: Energy Minimization in Pure PyTorch

Introduction

Part 1 covered the theory: PCNs replace the global loss.backward() with iterative, local energy minimization. Here we turn that into working PyTorch code -- a network that learns without backpropagating through the full graph.

The PCN Layer

A standard nn.Linear maps inputs to outputs in one shot. A PCN layer is different: it holds a dynamic latent state that gets updated during an inference phase, and it computes its own local prediction error against the layer below.

import torch
import torch.nn as nn

class PCNLayer(nn.Module):
    def __init__(self, in_features, out_features, activation=torch.tanh):
        super().__init__()
        self.out_features = out_features
        self.activation = activation

        # Generative weights: map this layer's state (out_features)
        # down to a prediction of the layer below (in_features)
        self.W = nn.Linear(out_features, in_features, bias=True)

        # Internal states for the PCN
        self.latent_state = None
        self.prediction_error = None

    def forward(self):
        # Top-down prediction of the layer below: W f(v) + b
        return self.W(self.activation(self.latent_state))

    def compute_prediction_error(self, target_state):
        # Local error: actual state of the layer below minus its prediction
        prediction = self.forward()
        self.prediction_error = target_state - prediction
        return self.prediction_error
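The direction of the generative map is easy to get backwards. In this minimal sketch (the sizes 32 and 784 and the batch of 8 are arbitrary illustrative choices), a layer with a 32-dimensional latent state predicts a 784-dimensional layer below, so its nn.Linear maps 32 features to 784, the reverse of a feed-forward encoder:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, below_dim, latent_dim = 8, 784, 32  # illustrative sizes

# Generative weights map this layer's latent state DOWN to the layer below,
# i.e. nn.Linear(out_features, in_features) in PCNLayer's naming
W = nn.Linear(latent_dim, below_dim, bias=True)

latent_state = torch.randn(batch, latent_dim)   # this layer's state v
target_state = torch.randn(batch, below_dim)    # stand-in for the layer below

prediction = W(torch.tanh(latent_state))        # W f(v) + b
prediction_error = target_state - prediction    # error at the layer below

print(prediction.shape, prediction_error.shape)  # torch.Size([8, 784]) torch.Size([8, 784])
print(W.weight.shape)                            # torch.Size([784, 32])
```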

We use torch.tanh instead of ReLU. In an iterative PCN, ReLU's zero derivative for negative inputs cuts off the bottom-up error gradient to any latent unit that goes negative, and its zero output removes that unit from the prediction of the layer below, so such units die during inference and learning stalls. During MNIST classification, ReLU flatlined accuracy at exactly 10%, the level of random guessing over ten classes, because negative latent states could never recover during the inference walk. Tanh avoids this: its derivative, $1 - \tanh^2(v)$, is smooth and nonzero across the entire real line.
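The dead-unit effect is visible directly with autograd. This toy sketch (the helper latent_grad and the scalar values are illustrative assumptions) has one latent unit predicting one unit below; for a negative latent, the bottom-up energy gradient is exactly zero under ReLU but not under tanh:

```python
import torch

def latent_grad(activation, v0=-1.5, w=0.8, target=1.0):
    # Hypothetical toy: one latent unit v predicts one unit below as w * f(v)
    v = torch.tensor(v0, requires_grad=True)
    error = target - w * activation(v)
    energy = 0.5 * error ** 2
    (grad_v,) = torch.autograd.grad(energy, v)
    return grad_v.item()

relu_grad = latent_grad(torch.relu)
tanh_grad = latent_grad(torch.tanh)
print(relu_grad == 0.0)  # True: no bottom-up signal reaches the negative latent
print(tanh_grad != 0.0)  # True: the latent can still move during inference
```

In a full PCN the latent also receives the top-down term from its own prediction error, but the bottom-up term is the one that lets it explain the layer below.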

Weight initialization matters too. We apply Xavier/Glorot uniform initialization to the generative weight matrices and zero-initialize the biases. Because the PCN's top-down predictions must initially be reasonable enough for the inference phase to converge, poor initialization can cause the energy landscape to be too flat or too chaotic for the latent states to settle.
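One way to apply this initialization is a small helper called once per layer after construction. This is a sketch: init_generative_weights is a hypothetical name, and it assumes the generative map is an nn.Linear, as in PCNLayer. nn.init.xavier_uniform_ and nn.init.zeros_ are standard PyTorch initializers.

```python
import torch
import torch.nn as nn

def init_generative_weights(linear: nn.Linear) -> None:
    # Xavier/Glorot uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
    nn.init.xavier_uniform_(linear.weight)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

torch.manual_seed(0)
W = nn.Linear(32, 784)  # generative map: 32-dim latent -> 784-dim layer below
init_generative_weights(W)

bound = (6.0 / (32 + 784)) ** 0.5
print(W.weight.abs().max().item() <= bound + 1e-6)  # True
print(torch.count_nonzero(W.bias).item())           # 0
```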

Phase 1: The Inference Walk

PCNs do not compute their output in one forward pass. Instead, every unclamped latent state undergoes gradient descent on the total energy, the sum of squared prediction errors across all layers, while the weights stay frozen throughout. This is the fundamental departure from standard neural networks: where a feed-forward net computes its output in a single sweep, a PCN enters an iterative relaxation loop that may run for 20 to 50 steps before settling on a stable internal representation.
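Written out, the quantity the walk descends is the following. The indexing is our own convention, chosen to match the code: $v_0$ is the clamped sensory input, layer $l$ (for $l = 1, \dots, L$) predicts layer $l-1$, and $\eta_v$ stands for inference_lr.

$$
\epsilon_{l-1} = v_{l-1} - \left(W_l\, f(v_l) + b_l\right), \qquad
E = \sum_{l=1}^{L} \tfrac{1}{2} \left\lVert \epsilon_{l-1} \right\rVert^2
$$

$$
\frac{\partial E}{\partial v_l} = \epsilon_l - f'(v_l) \odot \left(W_l^{\top} \epsilon_{l-1}\right),
\qquad
v_l \leftarrow v_l - \eta_v \frac{\partial E}{\partial v_l}
$$

The $\epsilon_l$ term is absent for the top layer ($l = L$), since nothing predicts it. Each latent is pulled toward its top-down prediction and pushed to better explain the layer below, and both terms involve only neighbouring layers; with tanh, $f'(v) = 1 - \tanh^2(v)$.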

def inference_step(self, inference_lr):
    total_energy = 0.0
    current_target = self.input_data # Fixed sensory input

    # 1. Compute prediction errors bottom-up
    for i in range(self.num_layers):
        error = self.layers[i].compute_prediction_error(current_target)
        # Energy = 1/2 * ||error||^2
        layer_energy = 0.5 * torch.sum(error ** 2)
        total_energy += layer_energy
        current_target = self.layers[i].latent_state

    # 2. Extract latent states that require gradients
    latent_states = [layer.latent_state for layer in self.layers
                     if layer.latent_state.requires_grad]

    if len(latent_states) == 0:
        return total_energy.item()

    # 3. Gradient w.r.t. latent states ONLY (weights are frozen);
    #    create_graph=False (the default) skips building a second-order graph
    gradients = torch.autograd.grad(total_energy, latent_states,
                                    create_graph=False)

    # 4. Update the latent states (Minimize Energy)
    with torch.no_grad():
        grad_idx = 0
        for layer in self.layers:
            if layer.latent_state.requires_grad:
                layer.latent_state.sub_(inference_lr * gradients[grad_idx])
                grad_idx += 1

    return total_energy.item()
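To watch the walk converge, the sketch below wraps the layer in a minimal container and runs $T=50$ inference steps on random data. PCNNetwork, reset_states, the small random initial states, the layer sizes and inference_lr=0.1 are all our assumptions rather than the article's setup; the inference_step body follows the method above, with the index bookkeeping replaced by zip.

```python
import torch
import torch.nn as nn


class PCNLayer(nn.Module):
    # Same layer as in the article: generative weights predict the layer below
    def __init__(self, in_features, out_features, activation=torch.tanh):
        super().__init__()
        self.out_features = out_features
        self.activation = activation
        self.W = nn.Linear(out_features, in_features, bias=True)
        self.latent_state = None
        self.prediction_error = None

    def forward(self):
        return self.W(self.activation(self.latent_state))

    def compute_prediction_error(self, target_state):
        self.prediction_error = target_state - self.forward()
        return self.prediction_error


class PCNNetwork:
    # Hypothetical container; `sizes` runs from the input (bottom) to the top
    def __init__(self, sizes):
        self.layers = [PCNLayer(sizes[i], sizes[i + 1])
                       for i in range(len(sizes) - 1)]
        self.num_layers = len(self.layers)
        self.input_data = None

    def reset_states(self, x, std=0.1):
        # Assumption: latent states start as small random values
        self.input_data = x
        for layer in self.layers:
            state = std * torch.randn(x.size(0), layer.out_features)
            layer.latent_state = state.requires_grad_()

    def inference_step(self, inference_lr):
        total_energy = 0.0
        current_target = self.input_data  # fixed sensory input
        for layer in self.layers:
            error = layer.compute_prediction_error(current_target)
            total_energy = total_energy + 0.5 * torch.sum(error ** 2)
            current_target = layer.latent_state
        latent_states = [layer.latent_state for layer in self.layers
                         if layer.latent_state.requires_grad]
        if not latent_states:
            return total_energy.item()
        # Gradients w.r.t. latent states only; weights receive nothing
        gradients = torch.autograd.grad(total_energy, latent_states,
                                        create_graph=False)
        with torch.no_grad():
            for state, grad in zip(latent_states, gradients):
                state.sub_(inference_lr * grad)  # descend the energy
        return total_energy.item()


torch.manual_seed(0)
net = PCNNetwork([20, 10, 5])  # 20-dim input, two latent layers
net.reset_states(torch.randn(16, 20))

energies = [net.inference_step(inference_lr=0.1) for _ in range(50)]
print(f"energy: {energies[0]:.2f} -> {energies[-1]:.2f}")
```

The printed energy should drop as the latent states settle; if it grows or oscillates instead, inference_lr is too large for the energy surface.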

We deliberately use torch.autograd.grad instead of the usual loss.backward(). The difference is critical: backward() accumulates gradients into the .grad attribute of every leaf tensor in the graph that requires gradients, including the weights. autograd.grad computes and returns gradients only for the tensors we pass it, here the latent states, and writes nothing to any .grad attribute, so the weights stay completely frozen. We also pass create_graph=False (the default) because we do not need second-order gradients: the inference walk is pure first-order gradient descent on the energy surface, and not building a graph of the gradient computation saves memory.
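The difference is easy to check on a toy layer. In this sketch (shapes and tensors are arbitrary), autograd.grad returns the latent gradient while leaving every .grad attribute untouched, whereas backward() fills in .grad for the weights as well:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
W = nn.Linear(4, 6)                          # generative weights (trainable)
v = torch.randn(3, 4, requires_grad=True)    # latent state
x = torch.randn(3, 6)                        # layer below

def energy():
    return 0.5 * torch.sum((x - W(torch.tanh(v))) ** 2)

# autograd.grad: returns dE/dv and writes nothing into any .grad attribute
(grad_v,) = torch.autograd.grad(energy(), [v], create_graph=False)
print(W.weight.grad is None, v.grad is None)  # True True

# backward(): accumulates into .grad of EVERY leaf that requires grad
energy().backward()
print(W.weight.grad is None, v.grad is None)  # False False
```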

Phase 2: The Hebbian Weight Update

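After the inference phase runs (e.g., $T=50$ iterations) and the latent states settle, we update the weights. No error signal is chained backward through the network: each layer's update uses only its post-synaptic prediction error ($\epsilon_{l-1}$) and its pre-synaptic activation ($f(v_l)$), giving $\Delta W_l = \eta\, \epsilon_{l-1} f(v_l)^{\top}$ (averaged over the batch) and $\Delta b_l = \eta\, \epsilon_{l-1}$ for learning rate $\eta$: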

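def update_weights(self, learning_rate):
    with torch.no_grad():
        current_target = self.input_data
        for layer in self.layers:
            # Recompute the error from the settled states: the error cached
            # during the last inference step predates that step's update
            error = layer.compute_prediction_error(current_target)
            pre_synaptic = layer.activation(layer.latent_state)

            # Local Hebbian update rule (negative gradient of the local energy)
            # delta W = error^T @ pre_synaptic, averaged over the batch
            delta_W = torch.matmul(error.t(), pre_synaptic) / error.size(0)
            delta_bias = torch.mean(error, dim=0)

            # Update weights locally
            layer.W.weight.add_(learning_rate * delta_W)
            if layer.W.bias is not None:
                layer.W.bias.add_(learning_rate * delta_bias)

            current_target = layer.latent_state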
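As a sanity check that this local rule is still gradient descent, just on each layer's own energy: for a single layer, error^T f(v) / B equals the negative autograd gradient of $\tfrac{1}{2}\lVert\epsilon_{l-1}\rVert^2$ with respect to W, divided by the batch size B. The shapes and random tensors are illustrative assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, below_dim, latent_dim = 8, 6, 4
W = nn.Linear(latent_dim, below_dim)   # generative weights of one layer
v = torch.randn(B, latent_dim)         # settled latent state (fixed)
x = torch.randn(B, below_dim)          # settled state of the layer below

# Local energy of this layer only
error = x - W(torch.tanh(v))
energy = 0.5 * torch.sum(error ** 2)
grad_W, grad_b = torch.autograd.grad(energy, [W.weight, W.bias])

# Hebbian rule from update_weights: post-synaptic error x pre-synaptic activity
with torch.no_grad():
    delta_W = torch.matmul(error.t(), torch.tanh(v)) / B
    delta_bias = torch.mean(error, dim=0)

print(torch.allclose(delta_W, -grad_W / B, atol=1e-6))     # True
print(torch.allclose(delta_bias, -grad_b / B, atol=1e-6))  # True
```

Because this gradient involves only the layer's own error and its own input activity, each layer can compute it independently once inference has settled.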

Supervised Learning: Label Clamping

To use a PCN for classification, we clamp the top layer's latent state to the one-hot encoded label and freeze it (requires_grad = False). During the inference walk, the label information propagates downward through the generative weights, shaping the internal representations of every hidden layer. This is the opposite of a feed-forward net, where label information only enters at the loss function after the forward pass is complete.
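A minimal sketch of the clamping setup, assuming a 10-class problem, a 32-unit hidden layer and small random initial hidden states (all illustrative choices): the one-hot label tensor is created without requires_grad, so the filter in inference_step drops it from the list of states to update.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, hidden_dim = 10, 32
labels = torch.tensor([3, 7, 0])  # a batch of 3 class labels

# Top layer: clamped to the one-hot label and frozen (no gradient)
top_state = F.one_hot(labels, num_classes).float()
print(top_state.requires_grad)    # False

# Hidden layer: free to move during the inference walk
# (assumption: small random initial values)
hidden_state = (0.1 * torch.randn(labels.size(0), hidden_dim)).requires_grad_()
print(hidden_state.requires_grad)  # True

# inference_step keeps only states with requires_grad=True,
# so the clamped label is never updated
free_states = [s for s in (hidden_state, top_state) if s.requires_grad]
print(len(free_states))           # 1
```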

At prediction time, we leave the top layer unclamped and let the entire network relax freely. The top layer's settled state becomes the network's "belief" about the input's class. We read it out with a softmax to get class probabilities.
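Reading out the prediction then takes two lines. In this sketch, settled_top is a random stand-in for the top layer's state after a free relaxation, not the output of a real network:

```python
import torch

torch.manual_seed(0)
# Stand-in for the top layer's settled latent state after an unclamped
# relaxation: batch of 3 samples, 10 classes
settled_top = torch.randn(3, 10)

probs = torch.softmax(settled_top, dim=1)  # class probabilities per sample
predicted = probs.argmax(dim=1)            # most likely class per sample

print(probs.sum(dim=1))                    # each row sums to 1
print(predicted.shape)                     # torch.Size([3])
```

Since softmax is monotonic, argmax over the probabilities picks the same class as argmax over the raw settled state; the softmax matters only when probabilities, rather than a single label, are needed.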

Next: Benchmarks

The PCN learns via local energy minimization. Part 3 benchmarks it against an identical MLP on nonlinear regression and MNIST classification.