Deconstructing Hopfield Networks from Scratch

Part 2: PyTorch Implementation

Abstract

In Part 1, we developed the mathematical foundations of Hopfield Networks. Now we translate every equation into working PyTorch code. We implement both the classical binary Hopfield network (1982) with Hebbian learning and asynchronous updates, and the modern continuous Hopfield network (2020) with log-sum-exp energy and softmax retrieval. Zero external frameworks beyond PyTorch---every operation is explicit.

Implementation Overview

Our implementation consists of two classes:

  1. ClassicalHopfield: Binary $\{-1, +1\}$ patterns, Hebbian weight matrix, asynchronous sign update, quadratic energy.
  2. ModernHopfield: Continuous $\mathbb{R}^d$ patterns, direct pattern storage, softmax-weighted retrieval, log-sum-exp energy.

Both classes expose three core operations: store(), retrieve(), and energy().
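The article does not show the constructors, so as a minimal sketch (the `__init__` signatures here are assumptions), the two classes can be scaffolded like this:

```python
import torch

class ClassicalHopfield:
    """Binary {-1, +1} patterns, Hebbian weights, asynchronous sign updates."""
    def __init__(self, size: int):
        self.size = size                      # number of neurons N
        self.W = torch.zeros(size, size)      # weight matrix, filled by store()
        self.patterns = None

class ModernHopfield:
    """Continuous patterns, softmax retrieval, log-sum-exp energy."""
    def __init__(self, dim: int):
        self.dim = dim                        # pattern dimension d
        self.patterns = None                  # (P, dim) matrix, set by store()
```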

Classical Hopfield: Weight Matrix Construction

The Hebbian learning rule from Part 1 translates directly:

def store(self, patterns: torch.Tensor) -> None:
    self.patterns = patterns.clone()
    num_patterns = patterns.shape[0]

    self.W = torch.zeros(self.size, self.size)
    for i in range(num_patterns):
        xi = patterns[i].unsqueeze(1)   # (size, 1)
        self.W += xi @ xi.T             # outer product

    self.W /= self.size                 # normalize by N
    self.W.fill_diagonal_(0.0)          # zero diagonal

This implements $W = \frac{1}{N} \sum_{\mu=1}^P \boldsymbol{\xi}^\mu (\boldsymbol{\xi}^\mu)^\top$ with $W_{ii} = 0$.
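Since a sum of outer products $\sum_\mu \boldsymbol{\xi}^\mu (\boldsymbol{\xi}^\mu)^\top$ equals $\Xi^\top \Xi$, the loop can also be written as a single matrix product. A quick check that the two forms agree:

```python
import torch

torch.manual_seed(0)
size = 8
patterns = 2 * torch.randint(0, 2, (3, size)).float() - 1  # three random +/-1 patterns

# Loop form, as in store()
W_loop = torch.zeros(size, size)
for i in range(patterns.shape[0]):
    xi = patterns[i].unsqueeze(1)   # (size, 1)
    W_loop += xi @ xi.T             # outer product
W_loop /= size
W_loop.fill_diagonal_(0.0)

# Vectorized form: sum of outer products == Xi^T Xi
W_vec = (patterns.T @ patterns) / size
W_vec.fill_diagonal_(0.0)
```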

Classical Hopfield: Asynchronous Update

Retrieval uses random asynchronous updates---one neuron at a time:

def retrieve(self, query, max_steps=20):
    state = query.clone().float()
    energy_history = [self.energy(state).item()]

    for _ in range(max_steps):
        old_state = state.clone()
        order = torch.randperm(self.size)
        for i in order:
            h = self.W[i] @ state
            state[i] = 1.0 if h >= 0 else -1.0

        energy_history.append(self.energy(state).item())
        if torch.equal(state, old_state):
            break

    return state, energy_history

Why Asynchronous?

Synchronous updates (updating all neurons simultaneously) can oscillate between two states and never settle. Asynchronous updates guarantee that the energy never increases: because the weights are symmetric with zero diagonal, each individual neuron flip lowers the energy or leaves it unchanged:

$$ \Delta E_i = -(x_i^{\text{new}} - x_i) h_i \leq 0 $$

In our implementation, one "step" is a complete sweep through all $N$ neurons in random order. Convergence typically occurs in 2--3 sweeps for well-separated patterns.
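Putting store(), energy(), and retrieve() together, a self-contained run shows the energy descending sweep by sweep. (The store() here uses the equivalent vectorized form $\Xi^\top \Xi / N$ rather than the article's loop, purely for brevity.)

```python
import torch

class ClassicalHopfield:
    def __init__(self, size):
        self.size = size
        self.W = torch.zeros(size, size)

    def store(self, patterns):
        self.patterns = patterns.clone()
        # Vectorized Hebbian rule: W = Xi^T Xi / N, zero diagonal
        self.W = (patterns.T @ patterns) / self.size
        self.W.fill_diagonal_(0.0)

    def energy(self, state):
        return (-0.5 * state @ self.W @ state).item()

    def retrieve(self, query, max_steps=20):
        state = query.clone().float()
        energy_history = [self.energy(state)]
        for _ in range(max_steps):
            old_state = state.clone()
            for i in torch.randperm(self.size):   # one random sweep
                h = self.W[i] @ state
                state[i] = 1.0 if h >= 0 else -1.0
            energy_history.append(self.energy(state))
            if torch.equal(state, old_state):     # fixed point reached
                break
        return state, energy_history

torch.manual_seed(0)
net = ClassicalHopfield(64)
patterns = 2 * torch.randint(0, 2, (3, 64)).float() - 1
net.store(patterns)

query = patterns[0].clone()
query[torch.randperm(64)[:8]] *= -1   # flip 8 of 64 bits (12.5% noise)
recovered, history = net.retrieve(query)
```

The energy history is non-increasing across sweeps, and the corrupted bits are flipped back toward the stored pattern.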

Classical Hopfield: Energy Computation

def energy(self, state: torch.Tensor) -> float:
    return (-0.5 * state @ self.W @ state).item()

This is a single line because PyTorch handles the matrix-vector products. The energy $E = -\frac{1}{2} \mathbf{x}^\top W \mathbf{x}$ is a scalar quadratic form. During retrieval, we observe energy dropping from approximately $-11$ to $-48$, confirming that the dynamics are descending the energy landscape toward a stored pattern.
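As a quick sanity check on the quadratic form, a stored pattern should sit much lower in the energy landscape than an arbitrary state, which hovers near $E = 0$ in expectation. A small sketch:

```python
import torch

torch.manual_seed(0)
N = 32
patterns = 2 * torch.randint(0, 2, (2, N)).float() - 1

# Hebbian weights (vectorized form of the store() loop)
W = (patterns.T @ patterns) / N
W.fill_diagonal_(0.0)

def energy(state):
    return (-0.5 * state @ W @ state).item()

random_state = 2 * torch.randint(0, 2, (N,)).float() - 1
# energy(patterns[0]) is roughly -(N - 2)/2; energy(random_state) is near 0
```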

Modern Hopfield: Energy Function

The modern energy function uses log-sum-exp:

def energy(self, state, beta=1.0):
    similarities = self.patterns @ state   # (num_patterns,)
    lse = torch.logsumexp(beta * similarities, dim=0)
    return (-lse / beta + 0.5 * torch.sum(state ** 2)).item()

This implements:

$$ E(\mathbf{x}) = -\frac{1}{\beta} \ln\!\left(\sum_{\mu=1}^P \exp(\beta \, \boldsymbol{\xi}^\mu \cdot \mathbf{x})\right) + \frac{1}{2}\|\mathbf{x}\|^2 $$

PyTorch's logsumexp is numerically stable (it subtracts the max before exponentiating), which is critical when $\beta$ is large.
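To see why the stability matters, compare a naive log-of-sum-of-exps against torch.logsumexp at large $\beta$ (an illustrative check with made-up similarity values):

```python
import torch

beta = 100.0
similarities = torch.tensor([8.0, 7.5, 3.0])

# Naive: exp(800) overflows float32 to inf, so the log is inf
naive = torch.log(torch.exp(beta * similarities).sum())

# Stable: max is subtracted before exponentiating
stable = torch.logsumexp(beta * similarities, dim=0)
# stable is close to beta * max(similarities) = 800
```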

Modern Hopfield: Softmax Retrieval

The iterative update rule:

def retrieve(self, query, beta=1.0, steps=5):
    state = query.clone().float()
    energy_history = [self.energy(state, beta)]

    for _ in range(steps):
        similarities = self.patterns @ state
        weights = F.softmax(beta * similarities, dim=0)
        state = self.patterns.T @ weights
        energy_history.append(self.energy(state, beta))

    return state, energy_history

Each step computes:

$$ \mathbf{x}^{\text{new}} = \Xi^\top \, \text{softmax}(\beta \, \Xi \, \mathbf{x}) $$

where $\Xi \in \mathbb{R}^{P \times d}$ is the pattern matrix. This is a weighted average of stored patterns, where the weights are softmax over similarity scores.
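A small self-contained run of this update rule, using two hand-picked orthogonal patterns so the result is easy to verify (the patterns and $\beta$ are illustrative choices, not from the article):

```python
import torch
import torch.nn.functional as F

# Two orthogonal stored patterns in R^4
patterns = torch.tensor([
    [ 1.0,  1.0, -1.0, -1.0],
    [ 1.0, -1.0,  1.0, -1.0],
])

def retrieve(query, beta=4.0, steps=3):
    state = query.clone()
    for _ in range(steps):
        weights = F.softmax(beta * (patterns @ state), dim=0)
        state = patterns.T @ weights   # convex combination of stored patterns
    return state

query = 0.8 * patterns[0]              # shrunken version of pattern 0
recovered = retrieve(query)
```

After one step the softmax puts almost all weight on the first pattern, and subsequent steps only sharpen the result.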

The Attention Connection: Implementation

The attention_retrieve method makes the connection explicit:

def attention_retrieve(self, query, beta=1.0):
    if query.dim() == 1:
        query = query.unsqueeze(0)       # (1, dim)

    K = self.patterns                     # (P, dim)
    V = self.patterns                     # (P, dim)

    scores = beta * (query @ K.T)         # (1, P)
    attn_weights = F.softmax(scores, dim=-1)
    output = attn_weights @ V             # (1, dim)

    return output.squeeze(0)

Compare with standard Transformer attention:

$$ \text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V $$

Setting $\beta = 1/\sqrt{d_k}$ and $K = V = \Xi$, these are identical. The Hopfield network is an attention layer where keys equal values (the stored patterns) and the query is the state to be retrieved.
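The equivalence is easy to confirm numerically: one Hopfield retrieval step with $\beta = 1/\sqrt{d}$ matches scaled dot-product attention with $K = V = \Xi$, up to floating-point tolerance:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
P, d = 5, 16
patterns = torch.randn(P, d)           # Xi: plays the role of both keys and values
query = torch.randn(d)

beta = 1.0 / d ** 0.5                  # matches the 1/sqrt(d_k) scaling

# One Hopfield retrieval step
hopfield = patterns.T @ F.softmax(beta * (patterns @ query), dim=0)

# Standard attention with Q = query, K = V = patterns
Q = query.unsqueeze(0)                 # (1, d)
attn = F.softmax((Q @ patterns.T) / d ** 0.5, dim=-1) @ patterns
```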

Utility Functions

We also implement helper functions for benchmarking:

def generate_binary_patterns(num_patterns, size):
    return 2 * torch.randint(0, 2, (num_patterns, size)).float() - 1

def corrupt_pattern(pattern, corruption_rate):
    corrupted = pattern.clone()
    num_flip = int(corruption_rate * pattern.numel())
    flip_idx = torch.randperm(pattern.numel())[:num_flip]
    corrupted[flip_idx] *= -1
    return corrupted

def hamming_accuracy(original, recovered):
    return (original == recovered).float().mean().item()

generate_binary_patterns samples from $\{-1, +1\}^N$ uniformly. corrupt_pattern flips a specified fraction of bits. hamming_accuracy measures the fraction of matching bits between the original and recovered patterns.
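The helpers compose naturally; a quick self-contained check (repeating the three functions so the snippet runs on its own):

```python
import torch

def generate_binary_patterns(num_patterns, size):
    return 2 * torch.randint(0, 2, (num_patterns, size)).float() - 1

def corrupt_pattern(pattern, corruption_rate):
    corrupted = pattern.clone()
    num_flip = int(corruption_rate * pattern.numel())
    flip_idx = torch.randperm(pattern.numel())[:num_flip]
    corrupted[flip_idx] *= -1
    return corrupted

def hamming_accuracy(original, recovered):
    return (original == recovered).float().mean().item()

torch.manual_seed(0)
pattern = generate_binary_patterns(1, 100)[0]
corrupted = corrupt_pattern(pattern, 0.25)
# randperm picks 25 distinct indices, so exactly 75% of bits survive
acc = hamming_accuracy(pattern, corrupted)
```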

Summary

| Method | Classical | Modern |
|---|---|---|
| store() | Hebbian outer products | Direct pattern matrix |
| retrieve() | Asynchronous sign update | Iterative softmax |
| energy() | $-\frac{1}{2}\mathbf{x}^\top W \mathbf{x}$ | $-\frac{1}{\beta}\text{lse}(\cdot) + \frac{1}{2}\|\mathbf{x}\|^2$ |
| attention_retrieve() | --- | Single-step attention |

The entire implementation fits in under 120 lines of Python. Every operation uses standard PyTorch tensor ops---no custom CUDA kernels, no external libraries. In Part 3, we benchmark both networks, measure storage capacity, and formally demonstrate the attention equivalence with numerical experiments.