
Deconstructing Echo State Networks

Part 2: Pure PyTorch Implementation

Introduction

Part 1 covered the math: why BPTT is fragile, how a frozen random reservoir sidesteps gradient computation, and how Ridge Regression gives us a closed-form readout. Here we translate that into a working PyTorch implementation.


The Constructor: Wiring the ESN

The EchoStateNetwork class extends nn.Module but uses requires_grad=False on every parameter except the readout. The constructor takes seven arguments: input_size, hidden_size, output_size, spectral_radius (default 0.9), sparsity (default 0.1), input_scaling (default 1.0), and leaky_rate (default 1.0).

class EchoStateNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size,
                 spectral_radius=0.9, sparsity=0.1,
                 input_scaling=1.0, leaky_rate=1.0):
        super().__init__()
        # Input weights: uniform [-input_scaling, input_scaling]
        self.W_in = nn.Parameter(
            (torch.rand(hidden_size, input_size) * 2 - 1) * input_scaling,
            requires_grad=False
        )

The input weight matrix $W_{in}$ maps from the input dimension to the reservoir dimension. It is drawn uniformly in $[-\text{input\_scaling}, +\text{input\_scaling}]$ and frozen. The scaling factor controls how strongly the input drives the reservoir dynamics.

Initializing the Sparse Reservoir

The reservoir matrix $W_{res}$ is the core data structure. It needs to be large (hundreds to thousands of neurons) and sparse. Sparsity keeps the matrix computationally manageable and introduces structural diversity--each neuron receives input from only a small random subset of other neurons.

# 1. Initialize sparse random matrix
W = torch.rand(hidden_size, hidden_size) - 0.5
mask = torch.rand(hidden_size, hidden_size) < sparsity
W = W * mask

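With sparsity=0.05 (5%), a 500-neuron reservoir has about 12,500 non-zero entries out of 250,000 total. Note that the sparsity argument is the fraction of entries kept, so the constructor default of 0.1 would keep about 25,000. Left unscaled, a random matrix like this can amplify or damp the state too strongly: the dynamics either become unstable and saturate the tanh, or die out before they carry any memory of past inputs. To encourage the Echo State Property, we compute the current spectral radius and rescale to a target value below 1.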

# 2. Scale W_res to achieve the desired Spectral Radius
eigenvalues = torch.linalg.eigvals(W)
current_sr = torch.max(torch.abs(eigenvalues)).item()

if current_sr > 0:
    W = W * (spectral_radius / current_sr)

self.W_res = nn.Parameter(W, requires_grad=False)

Setting requires_grad=False tells PyTorch this matrix is frozen: autograd never computes or stores a gradient for it, so no optimizer step can change it. The output weight matrix W_out is initialized to zeros with shape (output_size, 1 + input_size + hidden_size). The extra columns accommodate a bias term and the raw input, which are concatenated with the reservoir state to form the readout features.
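
The two initialization steps above can be sketched as one standalone function. This is a minimal illustration rather than the repository's code: the name make_reservoir and the seed argument are assumptions made here so the result is reproducible and easy to check.

```python
import torch

def make_reservoir(hidden_size, spectral_radius=0.9, sparsity=0.1, seed=0):
    """Hypothetical helper: sparse random reservoir rescaled to a target spectral radius."""
    gen = torch.Generator().manual_seed(seed)
    # Dense random weights in [-0.5, 0.5)
    W = torch.rand(hidden_size, hidden_size, generator=gen) - 0.5
    # Keep each entry with probability `sparsity` (the fraction of non-zeros)
    mask = torch.rand(hidden_size, hidden_size, generator=gen) < sparsity
    W = W * mask
    # Spectral radius = largest eigenvalue magnitude
    current_sr = torch.linalg.eigvals(W).abs().max().item()
    if current_sr > 0:
        W = W * (spectral_radius / current_sr)
    return W

W_res = make_reservoir(hidden_size=200, spectral_radius=0.9, sparsity=0.1)
density = (W_res != 0).float().mean().item()
measured_sr = torch.linalg.eigvals(W_res).abs().max().item()
print(f"density={density:.3f}, spectral radius={measured_sr:.4f}")
```

Recomputing the eigenvalues after rescaling is a cheap sanity check: the measured spectral radius should match the target up to floating-point error, and the density should sit close to the sparsity argument.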

The State Harvesting Loop

Since there is no BPTT, the forward pass just pushes input through the frozen reservoir and records the resulting states $X$ at each timestep. The update follows the leaky integrator equation from Part 1:

# Leaky integrator ESN state update
update = torch.tanh(u_t @ self.W_in.T + state @ self.W_res.T)
state = (1 - leaky_rate) * state + leaky_rate * update
states.append(state)

The state is initialized to zeros. At each timestep $t$, the input $u(t)$ is projected through $W_{in}$, the previous state is multiplied by $W_{res}$, and the sum passes through tanh. The leak rate $\alpha$ (leaky_rate in the code) then blends the old state with the new activation; with $\alpha = 1$ the old state is discarded entirely. After the loop, states are stacked into a tensor of shape (batch, seq_len, hidden_size) and concatenated with a bias column and the raw input to form the extended state matrix:

# Construct extended state: [1, u(t), x(t)]
bias = torch.ones(b, seq_len, 1, device=inputs.device)
extended_states = torch.cat([bias, inputs, states], dim=-1)
outputs = extended_states @ self.W_out.T
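
For readers who want to run the loop in isolation, the fragments above can be assembled into a self-contained sketch. The function name harvest_states and the toy weights are assumptions for illustration; the reservoir here is not rescaled to a spectral radius, it simply uses small weights so the example stays short.

```python
import torch

def harvest_states(inputs, W_in, W_res, leaky_rate=1.0):
    """Hypothetical helper: run a frozen reservoir over (batch, seq_len, input_size) inputs."""
    b, seq_len, _ = inputs.shape
    state = torch.zeros(b, W_res.shape[0], device=inputs.device)
    states = []
    for t in range(seq_len):
        u_t = inputs[:, t, :]
        # Leaky integrator update
        update = torch.tanh(u_t @ W_in.T + state @ W_res.T)
        state = (1 - leaky_rate) * state + leaky_rate * update
        states.append(state)
    return torch.stack(states, dim=1)  # (batch, seq_len, hidden_size)

torch.manual_seed(0)
input_size, hidden_size = 1, 50
W_in = torch.rand(hidden_size, input_size) * 2 - 1
W_res = (torch.rand(hidden_size, hidden_size) - 0.5) * 0.1  # small weights, illustration only
inputs = torch.randn(2, 30, input_size)

states = harvest_states(inputs, W_in, W_res, leaky_rate=0.5)
bias = torch.ones(2, 30, 1)
extended_states = torch.cat([bias, inputs, states], dim=-1)  # [1, u(t), x(t)]
print(states.shape, extended_states.shape)
```

Because each update is a convex blend of the previous state and a tanh output, every state entry stays within [-1, 1] as long as leaky_rate lies in [0, 1].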

Washout: Discarding the Transient

The reservoir starts from a zero state, which means the first several timesteps reflect initial transient dynamics rather than the true input-driven response. The fit() method accepts a washout parameter (default 100) and discards that many initial states before solving the regression. This is essential--including the transient corrupts the readout fit.
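
To see why the early states are unrepresentative, the sketch below drives the same reservoir from two different initial states and tracks how far apart the trajectories are. It is an illustration under stated assumptions, not the fit() implementation: the reservoir is scaled by its largest singular value (rather than the spectral radius) so the update is guaranteed to be a contraction, and the helper name run is hypothetical.

```python
import torch

torch.manual_seed(0)
hidden_size, seq_len, washout = 50, 200, 100

W_in = torch.rand(hidden_size, 1) * 2 - 1
W = torch.rand(hidden_size, hidden_size) - 0.5
# Assumption for this demo: scale by the largest singular value (not the
# spectral radius) so the state update is provably a contraction.
W_res = W * (0.9 / torch.linalg.matrix_norm(W, ord=2))

def run(initial_state, inputs):
    """Hypothetical helper: harvest states for a single (seq_len, 1) input sequence."""
    state, states = initial_state, []
    for u_t in inputs:
        state = torch.tanh(u_t @ W_in.T + state @ W_res.T)
        states.append(state)
    return torch.stack(states)  # (seq_len, hidden_size)

inputs = torch.sin(torch.linspace(0, 20, seq_len)).unsqueeze(1)
from_zero = run(torch.zeros(hidden_size), inputs)
from_random = run(torch.rand(hidden_size) * 2 - 1, inputs)

# Distance between the two trajectories at each timestep
gap = (from_zero - from_random).norm(dim=1)
print(f"gap at t=0: {gap[0].item():.3e}, gap at t={washout}: {gap[washout].item():.3e}")

# Discard the transient before fitting the readout
states_for_fit = from_zero[washout:]
```

Once the gap has shrunk to numerical noise, the state depends only on the input history, which is the regime the readout should be fitted on.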

Closed-Form Ridge Regression Readout

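The fit() method is where training happens. Given the extended state matrix $X$ (one row per timestep holding the bias, input, and reservoir state; named S in the code below) and targets $Y$, both after washout, we solve for $W_{out}$ via the Ridge Regression normal equations: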

$$ W_{out} = Y^T X (X^T X + \lambda \mathbf{I})^{-1} $$

Rather than computing the matrix inverse explicitly, we reformulate this as a linear system and use torch.linalg.solve, which is numerically more stable:

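# Prepend a bias column and the raw inputs U to the harvested states X
bias = torch.ones(X.shape[0], 1, dtype=X.dtype, device=X.device)
S = torch.cat([bias, U, X], dim=1)
identity = torch.eye(S.shape[1], dtype=S.dtype, device=S.device)

# S^T * S and S^T * Y
STS = S.T @ S
STY = S.T @ Y

# Solve (S^T S + lambda I) W^T = S^T Y
W_out_T = torch.linalg.solve(STS + ridge_lambda * identity, STY)

# Write the solution into the readout weights in one step (no gradient updates)
with torch.no_grad():
    self.W_out.copy_(W_out_T.T)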

The regularization term $\lambda I$ (with $\lambda = 10^{-4}$ by default) keeps the linear system well-conditioned and limits overfitting when the reservoir is very large relative to the training sequence. The matrix $S^T S$ has dimensions $(1 + d_{\text{input}} + d_{\text{hidden}}) \times (1 + d_{\text{input}} + d_{\text{hidden}})$. For our 500-neuron reservoir with a one-dimensional input, that is a $502 \times 502$ system. Solving it takes milliseconds on any modern CPU.
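
The equivalence between the linear-system form and the explicit-inverse formula can be checked on synthetic data. The sketch below is a standalone illustration: the function name ridge_readout and the random stand-in for the extended state matrix are assumptions, and double precision is used so the comparison is tight.

```python
import torch

def ridge_readout(S, Y, ridge_lambda=1e-4):
    """Hypothetical helper: solve (S^T S + lambda I) W^T = S^T Y and return W."""
    identity = torch.eye(S.shape[1], dtype=S.dtype, device=S.device)
    W_T = torch.linalg.solve(S.T @ S + ridge_lambda * identity, S.T @ Y)
    return W_T.T  # (output_size, n_features)

torch.manual_seed(0)
T, n_features, output_size = 400, 20, 2
S = torch.randn(T, n_features, dtype=torch.float64)   # stand-in for extended states
W_true = torch.randn(output_size, n_features, dtype=torch.float64)
Y = S @ W_true.T                                       # noiseless synthetic targets

W_out = ridge_readout(S, Y)

# Same answer as the textbook formula with an explicit inverse
I = torch.eye(n_features, dtype=torch.float64)
W_inverse = Y.T @ S @ torch.linalg.inv(S.T @ S + 1e-4 * I)
print(torch.allclose(W_out, W_inverse, atol=1e-8))
print((W_out - W_true).abs().max().item())
```

With noiseless targets the recovered weights differ from W_true only by the small shrinkage that the ridge term introduces.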

That is the entire training procedure: one forward pass to harvest states, one matrix solve to fit the readout. No epochs, no learning rate, no gradient tape. The implementation includes unit tests verifying that the spectral radius is correctly scaled, output shapes are correct, and fitting actually reduces prediction error on a simple sine-wave task.
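
As a rough end-to-end illustration of that procedure, the sketch below fits a one-step-ahead sine predictor without the class wrapper. It is not the repository's test suite: the sizes, the 700/300 train/test split, and the use of double precision are all assumptions chosen to keep the example small and numerically well-behaved.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64            # double precision keeps the solve well-behaved
hidden_size, washout, ridge_lambda = 100, 100, 1e-4

# Frozen random weights for a one-dimensional input
W_in = torch.rand(hidden_size, 1, dtype=dtype) * 2 - 1
W = torch.rand(hidden_size, hidden_size, dtype=dtype) - 0.5
W = W * (torch.rand(hidden_size, hidden_size) < 0.1)
W_res = W * (0.9 / torch.linalg.eigvals(W).abs().max())

# Task: predict the next sample of a sine wave
signal = torch.sin(0.1 * torch.arange(1001, dtype=dtype))
u, y = signal[:-1].unsqueeze(1), signal[1:].unsqueeze(1)   # (1000, 1) each

# Harvest reservoir states (leaky_rate = 1.0, so no blending term)
state, states = torch.zeros(hidden_size, dtype=dtype), []
for u_t in u:
    state = torch.tanh(u_t @ W_in.T + state @ W_res.T)
    states.append(state)
X = torch.stack(states)                                    # (1000, hidden_size)

# Extended states [1, u(t), x(t)]; drop the washout, then split train/test
S = torch.cat([torch.ones(len(u), 1, dtype=dtype), u, X], dim=1)
S_train, Y_train = S[washout:700], y[washout:700]
S_test, Y_test = S[700:], y[700:]

# Closed-form ridge readout on the training portion
A = S_train.T @ S_train + ridge_lambda * torch.eye(S.shape[1], dtype=dtype)
W_out = torch.linalg.solve(A, S_train.T @ Y_train).T       # (1, 1 + 1 + hidden_size)

mse_test = ((S_test @ W_out.T - Y_test) ** 2).mean().item()
mse_zero = (Y_test ** 2).mean().item()   # error of the all-zero initial readout
print(f"zero-readout MSE: {mse_zero:.4f}, fitted MSE on held-out steps: {mse_test:.2e}")
```

Comparing against the all-zero readout mirrors the "fitting actually reduces prediction error" check: a single solve should drive the held-out error far below the untrained baseline.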