We translate the math from Part 1 into working PyTorch code. There are two implementations: the classical binary Hopfield network (Hebbian learning, asynchronous updates) and the modern continuous variant (log-sum-exp energy, softmax retrieval). Nothing beyond PyTorch is needed, and every operation is written out explicitly.
Implementation Overview
Two classes:
- ClassicalHopfield: Binary $\{-1, +1\}$ patterns, Hebbian weight matrix, asynchronous sign update, quadratic energy.
- ModernHopfield: Continuous $\mathbb{R}^d$ patterns, direct pattern storage, softmax-weighted retrieval, log-sum-exp energy.
Both expose three operations: store(), retrieve(), and energy().
Classical Hopfield: Weight Matrix Construction
The Hebbian learning rule from Part 1:
```python
def store(self, patterns: torch.Tensor) -> None:
    self.patterns = patterns.clone()
    num_patterns = patterns.shape[0]
    self.W = torch.zeros(self.size, self.size)
    for i in range(num_patterns):
        xi = patterns[i].unsqueeze(1)  # (size, 1)
        self.W += xi @ xi.T  # outer product
    self.W /= self.size  # normalize by N
    self.W.fill_diagonal_(0.0)  # zero diagonal
```
This implements $W = \frac{1}{N} \sum_{\mu=1}^P \boldsymbol{\xi}^\mu (\boldsymbol{\xi}^\mu)^\top$ with $W_{ii} = 0$.
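A quick sanity check of the formula, as a minimal sketch with a single hand-picked pattern of length $N = 4$ (the pattern is chosen for illustration only). With the diagonal zeroed, each neuron's local field at the stored pattern is $\frac{N-1}{N}\xi_i$, so the sign update leaves the pattern unchanged:

```python
import torch

# One stored pattern, N = 4 neurons (illustrative values)
xi = torch.tensor([1.0, -1.0, 1.0, -1.0])
N = xi.numel()

# W = (1/N) * xi xi^T with the diagonal zeroed
W = torch.outer(xi, xi) / N
W.fill_diagonal_(0.0)

# Local field of every neuron when the network sits at the stored pattern
h = W @ xi                 # each entry is (N - 1) / N * xi_i = 0.75 * xi_i
recovered = torch.sign(h)  # equals xi, so the pattern is a fixed point
```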
Design Notes
- Outer product loop: We iterate over patterns rather than batching with einsum. The loop makes the Hebbian accumulation explicit.
- Normalization by $N$: Dividing by neuron count (not pattern count) keeps the local field $h_i$ on the order of $\pm 1$.
- Zero diagonal: Without this, each neuron receives a self-coupling $W_{ii} = P/N$ that biases it toward keeping its current sign, so corrupted bits resist flipping and spurious states become stable.
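For comparison with the loop, here is a minimal sketch of the batched alternative the first design note mentions. The helper names `hebbian_loop` and `hebbian_einsum` are hypothetical; both compute the same $W$, and the einsum contracts over the pattern index $\mu$ in one call (equivalent to `patterns.T @ patterns`):

```python
import torch

def hebbian_loop(patterns: torch.Tensor) -> torch.Tensor:
    # Loop form, mirroring store(): one outer product per pattern
    num_patterns, size = patterns.shape
    W = torch.zeros(size, size, dtype=patterns.dtype)
    for i in range(num_patterns):
        xi = patterns[i].unsqueeze(1)
        W += xi @ xi.T
    W /= size
    W.fill_diagonal_(0.0)
    return W

def hebbian_einsum(patterns: torch.Tensor) -> torch.Tensor:
    # Batched form: W_ij = (1/N) * sum_mu xi^mu_i xi^mu_j
    size = patterns.shape[1]
    W = torch.einsum("pi,pj->ij", patterns, patterns) / size
    W.fill_diagonal_(0.0)
    return W

torch.manual_seed(0)
patterns = 2 * torch.randint(0, 2, (5, 64)).float() - 1
W_loop = hebbian_loop(patterns)
W_fast = hebbian_einsum(patterns)
```

Because the entries are sums of $\pm 1$ divided by 64, both versions are exact in floating point and agree entry for entry.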
Classical Hopfield: Asynchronous Update
Retrieval updates one neuron at a time in random order:
```python
def retrieve(self, query, max_steps=20):
    state = query.clone().float()
    energy_history = [self.energy(state).item()]
    for _ in range(max_steps):
        old_state = state.clone()
        order = torch.randperm(self.size)
        for i in order:
            h = self.W[i] @ state
            state[i] = 1.0 if h >= 0 else -1.0
        energy_history.append(self.energy(state).item())
        if torch.equal(state, old_state):
            break
    return state, energy_history
```
Why Asynchronous?
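Synchronous updates (all neurons at once) can oscillate between two states. Asynchronous updates guarantee that the energy never increases. With $W$ symmetric and $W_{ii} = 0$, changing neuron $i$ by $\Delta x_i$ changes the energy by

$$\Delta E = -\Delta x_i \, h_i, \qquad h_i = \sum_{j} W_{ij} x_j.$$

The sign rule sets $x_i$ to the sign of $h_i$, so any flip has $\Delta x_i = 2\,\text{sign}(h_i)$. Hence $\Delta x_i \, h_i \ge 0$ and $\Delta E \le 0$: each individual update can only lower or maintain the energy.

One "step" is a full sweep through all $N$ neurons in random order. For well-separated patterns, convergence typically takes 2 to 3 sweeps.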
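The oscillation can be seen in a hand-built two-neuron network (an illustrative example, not taken from the text) with one positive coupling. Synchronous updates swap the two neurons forever, while asynchronous updates settle into a fixed point with lower energy:

```python
import torch

# Two neurons, symmetric weights, zero diagonal (illustrative)
W = torch.tensor([[0.0, 1.0],
                  [1.0, 0.0]])

def energy(x):
    return (-0.5 * x @ W @ x).item()

def sign(h):
    # Same tie-breaking as retrieve(): h >= 0 maps to +1
    return torch.where(h >= 0, torch.tensor(1.0), torch.tensor(-1.0))

# Synchronous: both neurons read the same old state
x = torch.tensor([1.0, -1.0])
sync_states = [x]
for _ in range(4):
    x = sign(W @ x)
    sync_states.append(x)
# sync_states alternates [1, -1] -> [-1, 1] -> [1, -1] -> ...

# Asynchronous: neuron 0 then neuron 1, each reading the latest state
x = torch.tensor([1.0, -1.0])
async_energies = [energy(x)]
for _ in range(2):  # two sweeps
    for i in range(2):
        x[i] = 1.0 if W[i] @ x >= 0 else -1.0
        async_energies.append(energy(x))
# x settles at [-1, -1]; energy drops from 1 to -1 and then stays put
```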
Classical Hopfield: Energy Computation
```python
def energy(self, state: torch.Tensor) -> torch.Tensor:
    # Returns a 0-dim tensor; retrieve() calls .item() to get a float
    return -0.5 * state @ self.W @ state
```
The body is a single line; PyTorch handles the vector-matrix products. The energy $E = -\frac{1}{2} \mathbf{x}^\top W \mathbf{x}$ is a scalar quadratic form. During retrieval, energy drops from roughly $-11$ to $-48$, confirming descent toward a stored pattern.
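Putting the three methods together gives a runnable end-to-end sketch. The `__init__` constructor is an assumption (the text only shows `store()`, `retrieve()`, and `energy()`), and the sizes ($N = 100$, $P = 3$, 10 flipped bits) are illustrative choices, not benchmark settings:

```python
import torch

class ClassicalHopfield:
    # Assumed constructor: the text relies on self.size, self.W, self.patterns
    def __init__(self, size: int):
        self.size = size
        self.W = torch.zeros(size, size)
        self.patterns = None

    def store(self, patterns: torch.Tensor) -> None:
        self.patterns = patterns.clone()
        self.W = torch.zeros(self.size, self.size)
        for i in range(patterns.shape[0]):
            xi = patterns[i].unsqueeze(1)  # (size, 1)
            self.W += xi @ xi.T            # Hebbian outer product
        self.W /= self.size                # normalize by N
        self.W.fill_diagonal_(0.0)         # no self-coupling

    def energy(self, state: torch.Tensor) -> torch.Tensor:
        return -0.5 * state @ self.W @ state

    def retrieve(self, query, max_steps=20):
        state = query.clone().float()
        energy_history = [self.energy(state).item()]
        for _ in range(max_steps):
            old_state = state.clone()
            for i in torch.randperm(self.size):  # one asynchronous sweep
                h = self.W[i] @ state
                state[i] = 1.0 if h >= 0 else -1.0
            energy_history.append(self.energy(state).item())
            if torch.equal(state, old_state):    # fixed point reached
                break
        return state, energy_history

torch.manual_seed(0)
N, P = 100, 3
patterns = 2 * torch.randint(0, 2, (P, N)).float() - 1
net = ClassicalHopfield(N)
net.store(patterns)

# Flip 10 of the 100 bits of the first pattern, then let the network repair it
query = patterns[0].clone()
query[torch.randperm(N)[:10]] *= -1
recovered, energies = net.retrieve(query)
accuracy = (recovered == patterns[0]).float().mean().item()
```

At this low load ($P/N = 0.03$) the corrupted query should fall back into the stored pattern, and `energies` should be non-increasing sweep by sweep.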
Modern Hopfield: Energy Function
The modern energy function uses log-sum-exp:
```python
def energy(self, state, beta=1.0):
    similarities = self.patterns @ state  # (num_patterns,)
    lse = torch.logsumexp(beta * similarities, dim=0)
    return (-lse / beta + 0.5 * torch.sum(state ** 2)).item()
```
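This implements:

$$E(\mathbf{x}) = -\frac{1}{\beta} \log \sum_{\mu=1}^{P} \exp\left(\beta \, (\boldsymbol{\xi}^\mu)^\top \mathbf{x}\right) + \frac{1}{2} \mathbf{x}^\top \mathbf{x}$$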
PyTorch's logsumexp is numerically stable (it subtracts the max before exponentiating), which is critical when $\beta$ is large.
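To see why stability matters, here is a minimal sketch with two illustrative similarity scores at $\beta = 1000$. The naive formula overflows in float32, while the max-shift trick (the standard technique for stable log-sum-exp) and `torch.logsumexp` both return $1000 + \log 2$:

```python
import math
import torch

beta = 1000.0
similarities = torch.tensor([1.0, 1.0])  # two patterns equally similar to the state
scaled = beta * similarities             # [1000.0, 1000.0]

# Naive form: exp(1000) overflows float32, so the result is inf
naive = torch.log(torch.exp(scaled).sum())

# Max-shift trick: factor out m = max before exponentiating
m = scaled.max()
shifted = m + torch.log(torch.exp(scaled - m).sum())

# Built-in version
stable = torch.logsumexp(scaled, dim=0)
```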
Modern Hopfield: Softmax Retrieval
The iterative update rule:
```python
import torch.nn.functional as F

def retrieve(self, query, beta=1.0, steps=5):
    state = query.clone().float()
    energy_history = [self.energy(state, beta)]
    for _ in range(steps):
        similarities = self.patterns @ state             # (num_patterns,)
        weights = F.softmax(beta * similarities, dim=0)  # sums to 1
        state = self.patterns.T @ weights                # weighted average of patterns
        energy_history.append(self.energy(state, beta))
    return state, energy_history
```
Each step computes:

$$\mathbf{x}^{\text{new}} = \Xi^\top \operatorname{softmax}\left(\beta \, \Xi \mathbf{x}\right)$$
where $\Xi \in \mathbb{R}^{P \times d}$ is the pattern matrix. This is a weighted average of stored patterns, where the weights are softmax over similarity scores.
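A minimal sketch of one update step with three hand-picked, mutually orthogonal patterns in $\mathbb{R}^4$ (illustrative values; `hopfield_step` is a hypothetical helper, not part of the class). At low $\beta$ the output blends all patterns; at high $\beta$ the softmax concentrates on the most similar pattern and a noisy query snaps back to it:

```python
import torch
import torch.nn.functional as F

# Rows of Xi are the stored patterns (mutually orthogonal)
Xi = torch.tensor([[ 1.0,  1.0, -1.0, -1.0],
                   [ 1.0, -1.0,  1.0, -1.0],
                   [-1.0,  1.0,  1.0, -1.0]])

def hopfield_step(Xi, x, beta):
    weights = F.softmax(beta * (Xi @ x), dim=0)  # one weight per stored pattern
    return Xi.T @ weights, weights               # convex combination of patterns

# Pattern 0 with its last entry flipped: similarities are [2, -2, -2]
query = torch.tensor([1.0, 1.0, -1.0, 1.0])

x_soft, w_soft = hopfield_step(Xi, query, beta=0.1)    # low beta: a blend
x_sharp, w_sharp = hopfield_step(Xi, query, beta=4.0)  # high beta: pattern 0
```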
The Attention Connection: Implementation
The attention_retrieve method writes out the equivalence directly:
```python
import torch.nn.functional as F

def attention_retrieve(self, query, beta=1.0):
    if query.dim() == 1:
        query = query.unsqueeze(0)  # (1, dim)
    K = self.patterns  # (P, dim)
    V = self.patterns  # (P, dim)
    scores = beta * (query @ K.T)  # (1, P)
    attn_weights = F.softmax(scores, dim=-1)
    output = attn_weights @ V  # (1, dim)
    return output.squeeze(0)
```
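Compare with standard Transformer attention:

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$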
Set $Q = \mathbf{x}^\top$, $K = V = \Xi$, and $\beta = 1/\sqrt{d_k}$, and the two are identical. A single Hopfield update step is an attention operation without learned projections: the keys and the values are both the stored patterns, and the query is the state being retrieved.
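This equivalence can be checked numerically, as a sketch assuming PyTorch 2.0 or newer (for `F.scaled_dot_product_attention`, whose default scale is $1/\sqrt{d}$ for query dimension $d$). The sizes and random data are illustrative; the Hopfield side repeats the math of `attention_retrieve` with plain tensor ops so the block runs on its own:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
P, d = 6, 16
Xi = torch.randn(P, d)     # stored patterns, shape (P, d)
query = torch.randn(d)     # state to retrieve from
beta = 1.0 / math.sqrt(d)  # Transformer scaling 1/sqrt(d_k)

# Hopfield side: one softmax retrieval step
weights = F.softmax(beta * (Xi @ query), dim=0)
hopfield_out = Xi.T @ weights

# Attention side: Q = query, K = V = Xi, with a leading batch dimension
# so the inputs have shape (batch, seq_len, d)
attn_out = F.scaled_dot_product_attention(
    query.view(1, 1, d), Xi.unsqueeze(0), Xi.unsqueeze(0)
).reshape(d)
```

The two outputs should agree up to float32 rounding.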
Utility Functions
We also implement helper functions for benchmarking:
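```python
import torch

def generate_binary_patterns(num_patterns, size):
    # Uniform samples from {-1, +1}^size, shape (num_patterns, size)
    return 2 * torch.randint(0, 2, (num_patterns, size)).float() - 1

def corrupt_pattern(pattern, corruption_rate):
    # Flip a fraction of the entries of a 1-D pattern of ±1 values
    corrupted = pattern.clone()
    # round() rather than int(): int(0.29 * 100) truncates to 28
    num_flip = round(corruption_rate * pattern.numel())
    flip_idx = torch.randperm(pattern.numel())[:num_flip]
    corrupted[flip_idx] *= -1
    return corrupted

def hamming_accuracy(original, recovered):
    # Fraction of positions where the two patterns agree
    return (original == recovered).float().mean().item()
```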
generate_binary_patterns samples from $\{-1, +1\}^N$ uniformly. corrupt_pattern flips a specified fraction of bits. hamming_accuracy measures the fraction of matching bits between the original and recovered patterns.
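One detail worth checking is how the flip count is computed from the rate. This minimal sketch (illustrative $N = 100$, rate 0.29) shows that `int()` truncates the floating-point product `0.29 * 100 = 28.999999999999996` to 28, while `round()` gives the intended 29; with distinct flip indices from `randperm`, the Hamming accuracy of the corrupted pattern is exactly $1 - \text{num\_flip}/N$:

```python
import torch

torch.manual_seed(0)
N = 100
pattern = 2 * torch.randint(0, 2, (N,)).float() - 1  # one pattern in {-1, +1}^N

rate = 0.29
truncated = int(rate * N)   # 28: float product is just below 29
rounded = round(rate * N)   # 29

# Corrupt with the rounded count, then measure agreement with the original
corrupted = pattern.clone()
corrupted[torch.randperm(N)[:rounded]] *= -1
accuracy = (pattern == corrupted).float().mean().item()  # 71 of 100 bits agree
```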
Summary
| Method | Classical | Modern |
|---|---|---|
| store() | Hebbian outer products | Direct pattern matrix |
| retrieve() | Asynchronous sign update | Iterative softmax |
| energy() | $-\frac{1}{2}\mathbf{x}^\top W \mathbf{x}$ | $-\frac{1}{\beta}\operatorname{lse}(\cdot) + \frac{1}{2}\lVert\mathbf{x}\rVert^2$ |
| attention_retrieve() | n/a | Single-step attention |
The whole thing is under 120 lines of Python, all standard PyTorch tensor ops. In Part 3, we benchmark both networks, measure storage capacity, and verify the attention equivalence numerically.