
Deconstructing Neural ODEs from Scratch

Part 2: PyTorch Implementation

Introduction

Part 1 covered the math. Now we turn those equations into PyTorch code.

The implementation lives in three files: ode_solver.py (Euler and RK4 fixed-step solvers), neural_ode.py (ODEFunc, NeuralODELayer, ODEClassifier), and train.py (spiral dataset, training loop, visualization). No torchdiffeq, no external ODE libraries.

ODE Solvers

Euler Step

The forward Euler update is a single line:

def euler_step(f, t, y, dt):
    return y + dt * f(t, y)

Because y has shape (batch, dim) and f returns the same shape, this naturally handles batched inputs. The scalar dt broadcasts across all dimensions.
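As a quick sanity check on the Euler update, the sketch below integrates the scalar test problem $dy/dt = -y$, $y(0) = 1$, whose exact solution at $t = 1$ is $e^{-1}$. The helper `integrate` and the test problem are illustrative assumptions, not part of the project files. Plain floats are used, so no tensor library is needed. Doubling the number of steps should roughly halve the error, which is the signature of a first-order method.

```python
import math

def euler_step(f, t, y, dt):
    return y + dt * f(t, y)

def integrate(f, y0, t0, t1, n_steps):
    # Hypothetical helper: repeated Euler steps on a fixed grid.
    dt = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = euler_step(f, t, y, dt)
        t += dt
    return y

def decay(t, y):
    # Test problem dy/dt = -y, exact solution y(t) = exp(-t).
    return -y

exact = math.exp(-1.0)
err_coarse = abs(integrate(decay, 1.0, 0.0, 1.0, 100) - exact)
err_fine = abs(integrate(decay, 1.0, 0.0, 1.0, 200) - exact)

# For a first-order method, halving dt should roughly halve the error.
ratio = err_coarse / err_fine
print(f"error ratio when dt is halved: {ratio:.2f}")
```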

RK4 Step

def rk4_step(f, t, y, dt):
    k1 = f(t, y)
    k2 = f(t + dt/2, y + dt/2 * k1)
    k3 = f(t + dt/2, y + dt/2 * k2)
    k4 = f(t + dt, y + dt * k3)
    return y + (dt/6) * (k1 + 2*k2 + 2*k3 + k4)

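Four evaluations of f, combined in a weighted average. The intermediate stages $\mathbf{k}_2$ and $\mathbf{k}_3$ both evaluate at the midpoint $t + \Delta t/2$, and the weights $1/6, 1/3, 1/3, 1/6$ are chosen so that the update matches the Taylor expansion of the true solution through order $\Delta t^4$. That gives a local error of $O(\Delta t^5)$ per step and a global error of $O(\Delta t^4)$.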
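The fourth-order behaviour can be checked empirically. The sketch below reuses the RK4 step on the scalar test problem $dy/dt = -y$ (an assumed example, with a hypothetical `integrate_rk4` helper) and compares the error at $t = 1$ for two step sizes. Halving $\Delta t$ should cut the error by a factor close to $2^4 = 16$.

```python
import math

def rk4_step(f, t, y, dt):
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt / 2 * k1)
    k3 = f(t + dt / 2, y + dt / 2 * k2)
    k4 = f(t + dt, y + dt * k3)
    return y + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

def integrate_rk4(f, y0, t0, t1, n_steps):
    # Hypothetical helper: repeated RK4 steps on a fixed grid.
    dt = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = rk4_step(f, t, y, dt)
        t += dt
    return y

def decay(t, y):
    # Test problem dy/dt = -y, exact solution y(t) = exp(-t).
    return -y

exact = math.exp(-1.0)
err_coarse = abs(integrate_rk4(decay, 1.0, 0.0, 1.0, 10) - exact)
err_fine = abs(integrate_rk4(decay, 1.0, 0.0, 1.0, 20) - exact)

# For a fourth-order method, halving dt should shrink the error ~16x.
ratio = err_coarse / err_fine
print(f"error ratio when dt is halved: {ratio:.1f}")
```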

Integration Loop

import torch

def ode_solve(f, y0, t_span, n_steps, method="euler"):
    # Pick the single-step integrator by name.
    step_fn = {"euler": euler_step, "rk4": rk4_step}[method]
    t0, t1 = t_span
    dt = (t1 - t0) / n_steps  # fixed step size
    trajectory = [y0]
    y, t = y0, t0
    for _ in range(n_steps):
        y = step_fn(f, t, y, dt)
        t = t + dt
        trajectory.append(y)
    # Stack the states into shape (n_steps+1, batch, dim).
    return torch.stack(trajectory, dim=0), ...

The function returns the full trajectory tensor of shape (n_steps+1, batch, dim) -- useful for visualization -- plus the time grid.
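The excerpt above elides how the time grid is built. One plausible way to complete it is sketched below, under the assumption that the grid is a `torch.linspace` over `t_span`. The name `ode_solve_sketch` is hypothetical, and only the Euler step is wired in to keep the example short.

```python
import math
import torch

def euler_step(f, t, y, dt):
    return y + dt * f(t, y)

def ode_solve_sketch(f, y0, t_span, n_steps):
    # Assumption: the time grid is uniform and built with linspace.
    t0, t1 = t_span
    dt = (t1 - t0) / n_steps
    ts = torch.linspace(t0, t1, n_steps + 1)
    trajectory = [y0]
    y = y0
    for i in range(n_steps):
        y = euler_step(f, ts[i], y, dt)
        trajectory.append(y)
    # States stacked along a new leading time axis.
    return torch.stack(trajectory, dim=0), ts

# Batched linear decay: 4 samples, 3 state dimensions.
y0 = torch.ones(4, 3)
traj, ts = ode_solve_sketch(lambda t, y: -y, y0, (0.0, 1.0), 50)
print(traj.shape, ts.shape)  # trajectory and time-grid shapes
```

Indexing `traj[k]` gives the whole batch at time `ts[k]`, which is the layout needed for plotting trajectories later.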

The Dynamics Network: ODEFunc

The core component of a Neural ODE is the function $f(t, \mathbf{y}; \boldsymbol{\theta})$ that defines the dynamics.

Architecture

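import torch.nn as nn

class ODEFunc(nn.Module):
    def __init__(self, dim, hidden=128):
        super().__init__()  # required before registering submodules
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden),  # +1 input for the time variable
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, dim),      # output matches the state dimension
        )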

Three key design decisions:

Time concatenation. The input is $[\,t\;;\;\mathbf{y}\,] \in \mathbb{R}^{d+1}$, not just $\mathbf{y}$. This makes the dynamics non-autonomous: the vector field can change as we integrate. Without time, the state follows the same fixed vector field at every instant, limiting expressivity.

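Tanh activation. We use $\tanh$ rather than ReLU for a specific reason: the dynamics $f$ should be Lipschitz continuous in $\mathbf{y}$ to guarantee a well-posed ODE, and smooth for the solver's error bounds to hold. Both activations are globally Lipschitz (with constant 1), but $\tanh$ is also infinitely differentiable, while ReLU has a discontinuous derivative at zero. That kink makes the vector field only piecewise smooth, which can degrade the accuracy of higher-order solvers such as RK4.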

Zero initialization of the last layer. We initialize the final linear layer to zero weights and biases:

nn.init.zeros_(self.net[-1].weight)
nn.init.zeros_(self.net[-1].bias)

At initialization the final layer outputs zero for every input, so $d\mathbf{y}/dt = 0$ and $\mathbf{y}(1) = \mathbf{y}(0)$ exactly. The network starts as the identity map and gradually learns to deform the space. Without this, random initial dynamics can push states to extreme values and destabilize the solver.
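The three decisions can be put together in one place. The sketch below is a minimal version under stated assumptions: the class name `ODEFuncSketch` is hypothetical, and the `forward` method (which the excerpt above does not show) is assumed to broadcast the scalar time to a `(batch, 1)` column and prepend it to the state.

```python
import torch
import torch.nn as nn

class ODEFuncSketch(nn.Module):
    def __init__(self, dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, dim),
        )
        # Zero-initialize the last layer so the dynamics start at zero.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, t, y):
        # Assumed forward pass: time becomes a (batch, 1) column
        # that is concatenated in front of the state.
        t_col = y.new_full((y.shape[0], 1), float(t))
        return self.net(torch.cat([t_col, y], dim=-1))

func = ODEFuncSketch(dim=6)
y = torch.randn(5, 6)
dydt = func(0.3, y)
print(dydt.shape, bool((dydt == 0).all()))
```

Because the last layer is all zeros, the output is zero regardless of the random weights in the earlier layers, so any solver leaves the state untouched before training begins.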

The Neural ODE Layer

class NeuralODELayer(nn.Module):
    def forward(self, y0):
        trajectory, _ = ode_solve(
            self.func, y0, (0.0, 1.0), self.n_steps, self.method
        )
        return trajectory[-1]  # y(1)

The layer takes in $\mathbf{y}_0$ and returns $\mathbf{y}(1)$ by calling the ODE solver. Standard PyTorch autograd handles the backward pass by backpropagating through all solver steps.

For the experimental adjoint path, we implement a custom torch.autograd.Function that only stores the final state and passes gradients through without storing the full trajectory.
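To make the standard (non-adjoint) path concrete, here is a minimal sketch of a layer whose backward pass is plain autograd through every solver step. The names `TinyFunc` and `NeuralODELayerSketch` are hypothetical stand-ins, the constructor arguments are assumed, and the Euler loop is inlined so the example is self-contained.

```python
import torch
import torch.nn as nn

class TinyFunc(nn.Module):
    # Hypothetical stand-in for ODEFunc: ignores t for brevity.
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, t, y):
        return torch.tanh(self.lin(y))

class NeuralODELayerSketch(nn.Module):
    def __init__(self, func, n_steps=10):
        super().__init__()
        self.func = func
        self.n_steps = n_steps

    def forward(self, y0):
        dt = 1.0 / self.n_steps
        y, t = y0, 0.0
        for _ in range(self.n_steps):
            # Each Euler step is recorded in the autograd graph.
            y = y + dt * self.func(t, y)
            t += dt
        return y  # y(1)

torch.manual_seed(0)
layer = NeuralODELayerSketch(TinyFunc(3), n_steps=10)
y1 = layer(torch.randn(8, 3))
y1.sum().backward()  # backpropagates through all 10 steps
weight_grad = layer.func.lin.weight.grad
bias_grad = layer.func.lin.bias.grad
print(y1.shape, weight_grad.shape)
```

The cost of this approach is memory: autograd keeps the intermediate activations of every step, so memory grows linearly with `n_steps`. That is the overhead an adjoint-style backward pass aims to avoid.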

The Classifier: Dimensional Lifting

Why Lifting is Necessary

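By the uniqueness theorem, two solutions of the same ODE can never occupy the same state at the same time, so trajectories in $\mathbb{R}^d$ cannot cross and the flow map $\mathbf{y}(0) \mapsto \mathbf{y}(1)$ preserves the topology of the input. In 2D, this leaves a continuous flow very little room to pull two interleaving spirals apart -- points from one class cannot simply pass through the other.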
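A one-dimensional analogue makes the constraint concrete. On the real line, no crossing means ordering is preserved: if $a < b$ at $t = 0$, then $a < b$ at every later time, so no ODE flow can swap the points $-1$ and $+1$. The sketch below checks this numerically with Euler steps on an arbitrary time-dependent vector field (`field` is an illustrative assumption, chosen only to be non-trivial).

```python
import math

def euler_step(f, t, y, dt):
    return y + dt * f(t, y)

def field(t, y):
    # Arbitrary time-dependent 1D vector field, for illustration only.
    return math.sin(3.0 * y) - 2.0 * y * math.cos(t)

n_steps = 1000
dt = 1.0 / n_steps
a, b, t = -1.0, 1.0, 0.0
order_preserved = True
for _ in range(n_steps):
    # Both points follow the same dynamics at the same time.
    a = euler_step(field, t, a, dt)
    b = euler_step(field, t, b, dt)
    t += dt
    order_preserved = order_preserved and (a < b)

print(f"a(1) = {a:.4f}, b(1) = {b:.4f}, order preserved: {order_preserved}")
```

With a step size this small the discrete Euler map is itself monotone, so the ordering holds at every step, mirroring the continuous-time result.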

Fix: lift the data to a higher-dimensional space before the ODE.

class ODEClassifier(nn.Module):
    def __init__(self, input_dim, n_classes, ode_dim=6, ...):
        super().__init__()  # required before registering submodules
        self.lift = nn.Linear(input_dim, ode_dim)  # 2D -> 6D
        self.odefunc = ODEFunc(ode_dim, hidden=128)
        self.ode_layer = NeuralODELayer(...)
        self.classifier = nn.Sequential(
            nn.Linear(ode_dim, 64),
            nn.ReLU(),
            nn.Linear(64, n_classes),
        )

The architecture is:

$$ \mathbf{x} \in \mathbb{R}^2 \;\xrightarrow{\text{lift}}\; \mathbf{y}_0 \in \mathbb{R}^6 \;\xrightarrow{\text{ODE}}\; \mathbf{y}(1) \in \mathbb{R}^6 \;\xrightarrow{\text{classify}}\; \hat{\mathbf{c}} \in \mathbb{R}^2 $$
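The pipeline can be traced end to end with a shape check. The sketch below is a compact, self-contained approximation rather than the project's class: `ODEClassifierSketch` is a hypothetical name, the dynamics network is shrunk to one hidden layer of width 32, and the Euler loop is inlined in place of `NeuralODELayer`.

```python
import torch
import torch.nn as nn

class ODEClassifierSketch(nn.Module):
    def __init__(self, input_dim=2, n_classes=2, ode_dim=6, n_steps=10):
        super().__init__()
        self.lift = nn.Linear(input_dim, ode_dim)  # 2D -> 6D
        # Assumed small dynamics net taking [t; y] as input.
        self.dynamics = nn.Sequential(
            nn.Linear(ode_dim + 1, 32), nn.Tanh(), nn.Linear(32, ode_dim)
        )
        self.classifier = nn.Sequential(
            nn.Linear(ode_dim, 64), nn.ReLU(), nn.Linear(64, n_classes)
        )
        self.n_steps = n_steps

    def forward(self, x):
        y = self.lift(x)  # y(0) in R^6
        dt = 1.0 / self.n_steps
        for i in range(self.n_steps):
            t_col = y.new_full((y.shape[0], 1), i * dt)
            y = y + dt * self.dynamics(torch.cat([t_col, y], dim=-1))
        return self.classifier(y)  # logits computed from y(1)

x = torch.randn(16, 2)
model = ODEClassifierSketch()
lifted = model.lift(x)
logits = model(x)
print(lifted.shape, logits.shape)
```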

The Augmented Neural ODE Perspective

This is closely related to the Augmented Neural ODE (Dupont et al., 2019), which appends extra zero-valued dimensions to the state before integrating. Here the extra dimensions come from a learned linear projection instead. Uniqueness still holds in the lifted space, but lifting from 2D to 6D gives four extra dimensions in which trajectories can move around each other, so untangling the spirals no longer requires them to cross.

The Baseline: ResNet MLP

For comparison, a discrete residual MLP that mirrors the Neural ODE structure:

class ResNetMLP(nn.Module):
    def forward(self, x):
        h = self.lift(x)
        dt = 1.0 / self.n_blocks
        for block in self.blocks:
            h = h + dt * block(h)  # Euler-like residual
        return self.classifier(h)

Each of the 20 blocks has its own set of weights (independent nn.Linear layers), while the Neural ODE reuses the same ODEFunc weights at every step. The dt = 1/n_blocks scaling mirrors the Euler method step size.
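The difference in weight sharing shows up directly in the parameter count. The sketch below compares one shared dynamics network against 20 independent blocks. The block architecture (one hidden layer of width 128) is an assumption made for illustration, since the block definition is not shown above, and `make_block` and `n_params` are hypothetical helpers.

```python
import torch.nn as nn

def n_params(module):
    # Total number of trainable scalars in a module.
    return sum(p.numel() for p in module.parameters())

dim, hidden, n_blocks = 6, 128, 20

def make_block(in_dim):
    # Assumed block shape: Linear -> Tanh -> Linear back to the state size.
    return nn.Sequential(
        nn.Linear(in_dim, hidden), nn.Tanh(), nn.Linear(hidden, dim)
    )

# Neural ODE: one network reused at every solver step (+1 input for time).
shared = make_block(dim + 1)
# ResNet: a separate network per block, no time input needed.
blocks = nn.ModuleList([make_block(dim) for _ in range(n_blocks)])

shared_count = n_params(shared)
resnet_count = n_params(blocks)
print(shared_count, resnet_count)
```

Under these assumptions the residual stack carries roughly 18 times as many dynamics parameters as the shared network, even though both apply the same number of updates to the state.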

Summary

In Part 3, we train everything on the spiral dataset and analyze the results: accuracy, speed, trajectories, and memory.