
Deconstructing Hopfield Networks from Scratch

Part 2: PyTorch Implementation

Abstract

We translate the math from Part 1 into working PyTorch code and build two implementations: the classical binary Hopfield network (Hebbian learning, asynchronous updates) and the modern continuous variant (log-sum-exp energy, softmax retrieval). Nothing beyond PyTorch is used, and every operation is written out explicitly.

Implementation Overview

Two classes:

  1. ClassicalHopfield: Binary $\{-1, +1\}$ patterns, Hebbian weight matrix, asynchronous sign update, quadratic energy.
  2. ModernHopfield: Continuous $\mathbb{R}^d$ patterns, direct pattern storage, softmax-weighted retrieval, log-sum-exp energy.

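Both expose three core operations: store(), retrieve(), and energy(). ModernHopfield adds a fourth, attention_retrieve(), which makes the link to Transformer attention explicit.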

Classical Hopfield: Weight Matrix Construction

The Hebbian learning rule from Part 1:

def store(self, patterns: torch.Tensor) -> None:
    # patterns: (P, N) tensor of +/-1 values; self.size holds N
    self.patterns = patterns.clone().float()
    num_patterns = self.patterns.shape[0]

    self.W = torch.zeros(self.size, self.size)
    for i in range(num_patterns):
        xi = self.patterns[i].unsqueeze(1)   # (size, 1) column vector
        self.W += xi @ xi.T                  # outer product, (size, size)

    self.W /= self.size                      # normalize by N
    self.W.fill_diagonal_(0.0)               # zero diagonal: no self-connections

This implements $W = \frac{1}{N} \sum_{\mu=1}^P \boldsymbol{\xi}^\mu (\boldsymbol{\xi}^\mu)^\top$ with $W_{ii} = 0$.
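The loop mirrors the sum term by term. Because $\sum_\mu \boldsymbol{\xi}^\mu (\boldsymbol{\xi}^\mu)^\top = \Xi^\top \Xi$ for a pattern matrix $\Xi$ with one pattern per row, the same weights can also be built with a single matmul. A sketch comparing the two; the function names hebbian_loop and hebbian_vectorized are illustrative, not part of the article's classes:

```python
import torch

def hebbian_loop(patterns: torch.Tensor) -> torch.Tensor:
    # Loop version mirroring store(): accumulate one outer product per pattern.
    num_patterns, size = patterns.shape
    W = torch.zeros(size, size)
    for mu in range(num_patterns):
        xi = patterns[mu].unsqueeze(1)
        W += xi @ xi.T
    W /= size
    W.fill_diagonal_(0.0)
    return W

def hebbian_vectorized(patterns: torch.Tensor) -> torch.Tensor:
    # (size, P) @ (P, size) sums all P outer products in one matmul.
    size = patterns.shape[1]
    W = patterns.T @ patterns / size
    W.fill_diagonal_(0.0)
    return W

torch.manual_seed(0)
patterns = 2 * torch.randint(0, 2, (5, 64)).float() - 1
W_loop = hebbian_loop(patterns)
W_vec = hebbian_vectorized(patterns)
print(torch.allclose(W_loop, W_vec))   # True: same matrix
print(torch.equal(W_vec, W_vec.T))     # True: W is symmetric
```

Both versions produce the same symmetric, zero-diagonal matrix; the vectorized form simply avoids a Python loop over patterns, which may matter once $P$ grows large.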


Classical Hopfield: Asynchronous Update

Retrieval updates one neuron at a time in random order:

def retrieve(self, query: torch.Tensor, max_steps: int = 20):
    state = query.clone().float()
    energy_history = [self.energy(state).item()]

    for _ in range(max_steps):                 # one step = one full sweep
        old_state = state.clone()
        order = torch.randperm(self.size)      # fresh random order per sweep
        for i in order.tolist():
            h = self.W[i] @ state              # local field of neuron i
            state[i] = 1.0 if h >= 0 else -1.0 # sign update, ties go to +1

        energy_history.append(self.energy(state).item())
        if torch.equal(state, old_state):      # no neuron changed: fixed point
            break

    return state, energy_history

Why Asynchronous?

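Synchronous updates (all neurons at once) can oscillate between two states. Asynchronous updates never increase the energy, provided $W$ is symmetric with a zero diagonal (the Hebbian rule guarantees both). Setting neuron $i$ to $x_i^{\text{new}} = \operatorname{sign}(h_i)$, where $h_i = \sum_j W_{ij} x_j$ is its local field, makes $x_i^{\text{new}} - x_i$ either zero or of the same sign as $h_i$, so each individual update can only lower or maintain the energy: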

$$ \Delta E_i = -(x_i^{\text{new}} - x_i) h_i \leq 0 $$
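The formula can be checked numerically. The sketch below uses a random symmetric zero-diagonal matrix as a stand-in for a Hebbian $W$ (an assumption for illustration), applies the sign update to each neuron of one fixed state in turn, always starting from that same state, and compares the exact energy change with $-(x_i^{\text{new}} - x_i) h_i$:

```python
import torch

torch.manual_seed(0)
N = 16

# Hypothetical test matrix: any symmetric W with a zero diagonal meets the
# assumptions behind the Delta E formula (a Hebbian W is one such matrix).
A = torch.randn(N, N)
W = (A + A.T) / 2
W.fill_diagonal_(0.0)

def energy(x: torch.Tensor) -> torch.Tensor:
    return -0.5 * x @ W @ x

x = 2 * torch.randint(0, 2, (N,)).float() - 1   # random +/-1 state

max_abs_err = 0.0
all_nonpositive = True
for i in range(N):
    h_i = W[i] @ x                          # local field of neuron i
    x_new = x.clone()
    x_new[i] = 1.0 if h_i >= 0 else -1.0    # sign update, ties go to +1
    delta_actual = (energy(x_new) - energy(x)).item()
    delta_formula = (-(x_new[i] - x[i]) * h_i).item()
    max_abs_err = max(max_abs_err, abs(delta_actual - delta_formula))
    all_nonpositive = all_nonpositive and delta_formula <= 0

print(max_abs_err, all_nonpositive)
```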

In retrieve(), one step is a full sweep through all $N$ neurons in a fresh random order, and max_steps caps the number of sweeps. For well-separated patterns, convergence typically takes 2 to 3 sweeps.
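A two-neuron toy network makes the synchronous failure mode concrete. With $W_{12} = W_{21} = 1$ (a hypothetical example chosen for illustration), updating both neurons at once from $(1, -1)$ flips them to $(-1, 1)$ and back again indefinitely, while updating them one at a time settles into a fixed point within a single sweep:

```python
import torch

# Hypothetical 2-neuron network with symmetric, zero-diagonal weights.
W = torch.tensor([[0.0, 1.0],
                  [1.0, 0.0]])

def sign(h: torch.Tensor) -> torch.Tensor:
    # Same tie-breaking as retrieve(): h >= 0 maps to +1.
    return torch.where(h >= 0, torch.ones_like(h), -torch.ones_like(h))

def energy(x: torch.Tensor) -> float:
    return (-0.5 * x @ W @ x).item()

# Synchronous: every neuron reads the old state.
x = torch.tensor([1.0, -1.0])
sync_states = [x.tolist()]
for _ in range(4):
    x = sign(W @ x)
    sync_states.append(x.tolist())

# Asynchronous: neurons update one at a time and see each other's new values.
x = torch.tensor([1.0, -1.0])
async_energies = [energy(x)]
for _ in range(2):              # two sweeps in fixed order 0, 1
    for i in range(2):
        x[i] = 1.0 if W[i] @ x >= 0 else -1.0
    async_energies.append(energy(x))

print(sync_states)     # alternates between [1, -1] and [-1, 1]
print(x.tolist(), async_energies)
```

The synchronous trajectory is a 2-cycle at constant energy, whereas the asynchronous one drops from energy 1 to -1 in the first sweep and stays at the fixed point $(-1, -1)$.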

Classical Hopfield: Energy Computation

def energy(self, state: torch.Tensor) -> torch.Tensor:
    # Quadratic form -1/2 x^T W x; returns a 0-d tensor (retrieve() calls .item())
    return -0.5 * state @ self.W @ state

It is a single line, since PyTorch handles the vector-matrix-vector product. The energy $E = -\frac{1}{2} \mathbf{x}^\top W \mathbf{x}$ is a scalar quadratic form. During retrieval, energy drops from roughly $-11$ to $-48$, consistent with descent into the basin of a stored pattern.
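energy() scores one state at a time. For benchmarking many states, a batched variant can evaluate $-\frac{1}{2}\mathbf{x}_b^\top W \mathbf{x}_b$ for every row of a $(B, N)$ tensor with one matmul. This is a sketch; batched_energy and the sizes used are hypothetical:

```python
import torch

def batched_energy(W: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
    # states: (B, N). Row b of (states @ W) is x_b^T W, so the row-wise
    # dot product with x_b gives x_b^T W x_b for every state at once.
    return -0.5 * ((states @ W) * states).sum(dim=1)

torch.manual_seed(0)
N, P = 100, 3
patterns = 2 * torch.randint(0, 2, (P, N)).float() - 1
W = patterns.T @ patterns / N          # Hebbian weights, as in store()
W.fill_diagonal_(0.0)

random_states = 2 * torch.randint(0, 2, (4, N)).float() - 1
states = torch.cat([patterns, random_states])       # (P + 4, N)

E_batch = batched_energy(W, states)
E_loop = torch.stack([-0.5 * s @ W @ s for s in states])
print(torch.allclose(E_batch, E_loop, atol=1e-4))
print(E_batch)
```

Stored patterns sit far below random states: with the diagonal removed, a stored pattern's energy is at most $-\frac{1}{2}(N - P)$, while a random $\pm 1$ state has expected energy zero.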

Modern Hopfield: Energy Function

The modern energy function uses log-sum-exp:

def energy(self, state, beta=1.0):
    similarities = self.patterns @ state   # (num_patterns,)
    lse = torch.logsumexp(beta * similarities, dim=0)
    return (-lse / beta + 0.5 * torch.sum(state ** 2)).item()

This implements:

$$ E(\mathbf{x}) = -\frac{1}{\beta} \ln\!\left(\sum_{\mu=1}^P \exp(\beta \, \boldsymbol{\xi}^\mu \cdot \mathbf{x})\right) + \frac{1}{2}\|\mathbf{x}\|^2 $$
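Differentiating term by term, using $\nabla_{\mathbf{z}} \operatorname{lse}(\mathbf{z}) = \operatorname{softmax}(\mathbf{z})$ with $\mathbf{z} = \beta\,\Xi\,\mathbf{x}$, gives

$$ \nabla_{\mathbf{x}} E = \mathbf{x} - \Xi^\top \operatorname{softmax}(\beta \, \Xi \, \mathbf{x}) $$

so one gradient-descent step with unit step size, $\mathbf{x} - \nabla_{\mathbf{x}} E$, lands exactly on the softmax retrieval update, and fixed points of retrieval are stationary points of $E$. A sketch checking this with autograd; the pattern matrix and state are random placeholders:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
P, d, beta = 4, 8, 2.0
Xi = torch.randn(P, d)                    # placeholder pattern matrix (P, d)
x = torch.randn(d, requires_grad=True)    # placeholder state

def energy(x: torch.Tensor) -> torch.Tensor:
    # Same formula as ModernHopfield.energy, returned as a tensor for autograd.
    lse = torch.logsumexp(beta * (Xi @ x), dim=0)
    return -lse / beta + 0.5 * torch.sum(x ** 2)

energy(x).backward()
xd = x.detach()
closed_form = xd - Xi.T @ F.softmax(beta * (Xi @ xd), dim=0)
print(torch.allclose(x.grad, closed_form, atol=1e-5))

# A unit-step gradient-descent move equals one softmax retrieval step.
x_gd = xd - x.grad
x_update = Xi.T @ F.softmax(beta * (Xi @ xd), dim=0)
print(torch.allclose(x_gd, x_update, atol=1e-5))
```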

PyTorch's logsumexp is numerically stable (it subtracts the max before exponentiating), which is critical when $\beta$ is large.
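To see why this matters, take a large $\beta$ with a few similarity scores. The naive formula overflows to infinity in float32, while the max-shift trick (reproduced by hand here) keeps every exponent at or below zero and agrees with torch.logsumexp. The scores and $\beta$ are arbitrary illustrative values:

```python
import torch

similarities = torch.tensor([10.0, 9.5, -3.0])
beta = 100.0
scaled = beta * similarities                       # largest entry is 1000

naive = torch.log(torch.sum(torch.exp(scaled)))    # exp(1000) overflows float32
m = scaled.max()
manual = m + torch.log(torch.sum(torch.exp(scaled - m)))   # max-shift trick
stable = torch.logsumexp(scaled, dim=0)

print(naive)    # tensor(inf)
print(manual, stable)
```

Dividing by $\beta$ as in energy() then gives a finite $-\frac{1}{\beta}\operatorname{lse} \approx -10$, the largest similarity, which the naive route cannot recover.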

Modern Hopfield: Softmax Retrieval

The iterative update rule (F is torch.nn.functional):

def retrieve(self, query, beta=1.0, steps=5):
    state = query.clone().float()
    energy_history = [self.energy(state, beta)]

    for _ in range(steps):
        similarities = self.patterns @ state
        weights = F.softmax(beta * similarities, dim=0)
        state = self.patterns.T @ weights
        energy_history.append(self.energy(state, beta))

    return state, energy_history

Each step computes:

$$ \mathbf{x}^{\text{new}} = \Xi^\top \, \text{softmax}(\beta \, \Xi \, \mathbf{x}) $$

where $\Xi \in \mathbb{R}^{P \times d}$ is the pattern matrix, with one stored pattern per row (matching self.patterns). The new state is a weighted average of the stored patterns, with weights given by a softmax over the similarity scores.
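The inverse temperature $\beta$ controls how sharply the softmax concentrates. A sketch with three orthonormal toy patterns (hypothetical values; softmax_update is a free-function stand-in for retrieve()): a large $\beta$ snaps to the most similar pattern in one step, $\beta = 0$ returns the mean of all patterns, and over repeated steps the energy should not increase, since each update is a step of the concave-convex procedure (CCCP) on $E$:

```python
import torch
import torch.nn.functional as F

def softmax_update(patterns: torch.Tensor, state: torch.Tensor, beta: float) -> torch.Tensor:
    # One step of x_new = Xi^T softmax(beta * Xi x); patterns is Xi, shape (P, d)
    weights = F.softmax(beta * (patterns @ state), dim=0)
    return patterns.T @ weights

def energy(patterns: torch.Tensor, state: torch.Tensor, beta: float) -> float:
    # Log-sum-exp energy: -lse(beta * Xi x) / beta + 0.5 * ||x||^2
    lse = torch.logsumexp(beta * (patterns @ state), dim=0)
    return (-lse / beta + 0.5 * torch.sum(state ** 2)).item()

patterns = torch.eye(3)                   # three orthonormal toy patterns, one per row
query = torch.tensor([0.9, 0.2, 0.1])     # most similar to pattern 0

sharp = softmax_update(patterns, query, beta=50.0)  # weight concentrates on pattern 0
flat = softmax_update(patterns, query, beta=0.0)    # uniform weights: mean pattern

# Repeated updates at a moderate beta: the energy should never go up.
state, beta = query.clone(), 2.0
history = [energy(patterns, state, beta)]
for _ in range(5):
    state = softmax_update(patterns, state, beta)
    history.append(energy(patterns, state, beta))

print(sharp, flat)
print(history)
```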

The Attention Connection: Implementation

The attention_retrieve method writes out the equivalence directly:

def attention_retrieve(self, query, beta=1.0):
    if query.dim() == 1:
        query = query.unsqueeze(0)       # (1, dim)

    K = self.patterns                     # (P, dim)
    V = self.patterns                     # (P, dim)

    scores = beta * (query @ K.T)         # (1, P)
    attn_weights = F.softmax(scores, dim=-1)
    output = attn_weights @ V             # (1, dim)

    return output.squeeze(0)

Compare with standard Transformer attention:

$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V $$

With $\beta = 1/\sqrt{d_k}$, $K = V = \Xi$, and $Q = \mathbf{x}^\top$ (the state as a row vector), the two expressions are identical. One Hopfield update step is therefore an attention operation in which the keys and values are both the stored patterns and the query is the state being retrieved.
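The equivalence can also be checked against PyTorch's built-in attention. A sketch, assuming PyTorch 2.0 or later for F.scaled_dot_product_attention, whose default scale is $1/\sqrt{d}$ for query dimension $d$; the patterns and query are random placeholders:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
P, d = 6, 16
Xi = torch.randn(P, d)         # stored patterns: keys and values
x = torch.randn(d)             # query: the state being retrieved
beta = 1.0 / math.sqrt(d)      # beta = 1/sqrt(d_k)

# One Hopfield update step: Xi^T softmax(beta * Xi x)
hopfield = Xi.T @ F.softmax(beta * (Xi @ x), dim=0)

# Built-in attention with Q = x, K = V = Xi, shapes (batch, seq_len, dim).
# Its default scale is 1/sqrt(query.size(-1)) = 1/sqrt(d).
attn = F.scaled_dot_product_attention(
    x.view(1, 1, d), Xi.unsqueeze(0), Xi.unsqueeze(0)
).reshape(d)

print(torch.allclose(hopfield, attn, atol=1e-5))
```

Note that attention layers in Transformers add learned projections for $Q$, $K$, and $V$; the check above covers only the projection-free case described in the text.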

Utility Functions

We also implement helper functions for benchmarking:

def generate_binary_patterns(num_patterns, size):
    return 2 * torch.randint(0, 2, (num_patterns, size)).float() - 1

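def corrupt_pattern(pattern, corruption_rate):
    # Flip exactly int(rate * numel) randomly chosen entries (+1 <-> -1).
    # Flattening makes the indices address elements for any pattern shape.
    corrupted = pattern.clone().flatten()
    num_flip = int(corruption_rate * pattern.numel())
    flip_idx = torch.randperm(pattern.numel())[:num_flip]
    corrupted[flip_idx] *= -1
    return corrupted.reshape(pattern.shape)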

def hamming_accuracy(original, recovered):
    return (original == recovered).float().mean().item()

generate_binary_patterns samples from $\{-1, +1\}^N$ uniformly. corrupt_pattern flips a randomly chosen set of exactly int(corruption_rate · N) bits, with no bit flipped twice. hamming_accuracy measures the fraction of matching bits between the original and recovered patterns.
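A minimal recall experiment ties the helpers together. The sketch restates them so it runs on its own, with hebbian_weights and async_retrieve as hypothetical free-function versions of store() and retrieve(); the sizes (3 patterns of 100 bits, 10% corruption) are arbitrary choices for illustration:

```python
import torch

def generate_binary_patterns(num_patterns, size):
    return 2 * torch.randint(0, 2, (num_patterns, size)).float() - 1

def corrupt_pattern(pattern, corruption_rate):
    corrupted = pattern.clone().flatten()
    num_flip = int(corruption_rate * pattern.numel())
    flip_idx = torch.randperm(pattern.numel())[:num_flip]
    corrupted[flip_idx] *= -1
    return corrupted.reshape(pattern.shape)

def hamming_accuracy(original, recovered):
    return (original == recovered).float().mean().item()

def hebbian_weights(patterns):
    # Hypothetical free-function version of ClassicalHopfield.store().
    W = patterns.T @ patterns / patterns.shape[1]
    W.fill_diagonal_(0.0)
    return W

def async_retrieve(W, query, max_sweeps=20):
    # Hypothetical free-function version of ClassicalHopfield.retrieve().
    state = query.clone()
    for _ in range(max_sweeps):
        old = state.clone()
        for i in torch.randperm(len(state)).tolist():
            state[i] = 1.0 if W[i] @ state >= 0 else -1.0
        if torch.equal(state, old):
            break
    return state

torch.manual_seed(0)
patterns = generate_binary_patterns(3, 100)
W = hebbian_weights(patterns)
noisy = corrupt_pattern(patterns[0], 0.1)
recovered = async_retrieve(W, noisy)

print(hamming_accuracy(patterns[0], noisy))      # about 0.9: 10 of 100 bits flipped
print(hamming_accuracy(patterns[0], recovered))
```

At this low load ($P/N = 0.03$), recovery is expected to be exact or nearly so.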

Summary

| Method | Classical | Modern |
|---|---|---|
| store() | Hebbian outer products | Direct pattern matrix |
| retrieve() | Asynchronous sign update | Iterative softmax |
| energy() | $-\frac{1}{2}\mathbf{x}^\top W \mathbf{x}$ | $-\frac{1}{\beta}\operatorname{lse}(\beta\,\Xi\,\mathbf{x}) + \frac{1}{2}\lVert\mathbf{x}\rVert^2$ |
| attention_retrieve() | n/a | Single-step attention |

The whole thing is under 120 lines of Python, all standard PyTorch tensor ops. In Part 3, we benchmark both networks, measure storage capacity, and verify the attention equivalence numerically.