Deconstructing Kolmogorov-Arnold Networks (KANs)

Part 2: Pure PyTorch Implementation

Introduction

Part 1 covered the math: the Kolmogorov-Arnold theorem and B-spline parameterization. Now we translate those formulas into a working PyTorch module. The goal is a single KANLayer that can replace nn.Linear as a drop-in.

View the complete source code on GitHub

1. The Architecture of a KAN Layer

A standard Linear layer computes $y = Wx + b$ and then applies a fixed activation $\sigma(y)$. A KANLayer instead computes $y_j = \sum_i \phi_{ij}(x_i)$: each output is a sum of learned univariate functions, one per input edge. Each edge function combines a fixed base activation with a learned spline:

$$ \phi(x) = w_b \,\text{SiLU}(x) + \sum_{i} c_i B_i(x) $$

The PyTorch module manages three components:

  1. The Grid: Fixed knots defining B-spline evaluation points.
  2. Base Weights: $w_b$, a linear layer applied to SiLU-activated input.
  3. Spline Weights: $c_i$, coefficients controlling the learned curve shape on each edge.
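As a concrete sketch of component 1, the knot grid can be built as a uniform partition of a core range, extended by `spline_order` extra knots on each side so that the highest-degree bases are well-defined at the boundaries. The $[-1, 1]$ core range matches the bounds mentioned in the gotchas below; the exact construction here is an assumption, not necessarily the repository's code.

```python
import torch

# Assumed hyperparameters: 5 core intervals, cubic splines
grid_size, spline_order = 5, 3

# Uniform spacing over the core range [-1, 1]
h = 2.0 / grid_size

# Extend by spline_order knots on each side; arange(-3, 9) scaled and
# shifted gives 12 strictly increasing knots from -2.2 to 2.2
grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0
```

After `spline_order` passes of the Cox-de Boor recursion, these `grid_size + 2 * spline_order + 1` knots yield `grid_size + spline_order` basis functions, which fixes the last dimension of the spline weight tensor.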

2. Implementing B-Splines: The Cox-de Boor Algorithm

The hard part is evaluating B-spline basis functions $B_i(x)$ for a full batch without Python for loops over batch or feature dimensions.

Cox-de Boor recursion starts from degree-0 indicator functions (1 inside a grid interval, 0 outside) and builds higher-degree splines by linearly interpolating between lower-degree ones.

Vectorized Code

def b_spline(self, x):
    # x: [batch, in_features] -> add a trailing dim so each sample and
    # feature broadcasts against the 1-D knot grid
    x = x.unsqueeze(-1)  # [batch, in_features, 1]

    # Degree 0: indicator of each half-open knot interval [t_i, t_{i+1}).
    # Note x exactly at the rightmost knot falls outside every interval.
    bases = ((x >= self.grid[:-1]) & (x < self.grid[1:])).to(x.dtype)

    # Cox-de Boor recursion: each pass raises the degree by one and
    # shrinks the basis count by one
    for k in range(1, self.spline_order + 1):
        left_num = x - self.grid[:-k-1]
        left_den = self.grid[k:-1] - self.grid[:-k-1]
        left = (left_num / left_den) * bases[:, :, :-1]

        right_num = self.grid[k+1:] - x
        right_den = self.grid[k+1:] - self.grid[1:-k]
        right = (right_num / right_den) * bases[:, :, 1:]

        bases = left + right

    return bases  # [batch, in_features, n_bases]

Broadcasting handles the entire batch and all input features in one pass.
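A quick way to sanity-check the recursion is a standalone copy of the method above: inside the core knot range, B-spline bases of any degree should be nonnegative and sum to one (the partition-of-unity property). The function name and the shared 1-D grid are assumptions for this sketch.

```python
import torch

def b_spline_bases(x, grid, spline_order):
    # Standalone version of the b_spline method, with the grid shared
    # across all input features
    x = x.unsqueeze(-1)  # [batch, in_features, 1]
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    for k in range(1, spline_order + 1):
        left = (x - grid[:-k-1]) / (grid[k:-1] - grid[:-k-1]) * bases[..., :-1]
        right = (grid[k+1:] - x) / (grid[k+1:] - grid[1:-k]) * bases[..., 1:]
        bases = left + right
    return bases  # [batch, in_features, n_bases]
```

Evaluating this on points inside $[-1, 1)$ and checking that `bases.sum(-1)` is all ones is a cheap regression test to keep around while tinkering with the slicing.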

3. The Forward Pass

With the B-spline bases evaluated, the forward pass reduces to two linear combinations:

def forward(self, x):
    # 1. Base activation
    base_output = F.linear(F.silu(x), self.base_weight)

    # 2. Spline activation
    splines = self.b_spline(x)
    splines = splines.view(x.shape[0], -1)
    spline_weight_flat = self.spline_weight.view(self.out_features, -1)
    spline_output = F.linear(splines, spline_weight_flat)

    # 3. Combine
    return self.scale_base * base_output + self.scale_spline * spline_output

Flattening the splines and spline weights lets us use F.linear to sum across all edges in a single matrix multiply.
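One way to convince yourself the flattening is safe is to compare it against an explicit per-edge contraction. The shapes below are illustrative, not taken from the repository:

```python
import torch
import torch.nn.functional as F

batch, in_f, out_f, n_bases = 4, 3, 5, 8
splines = torch.randn(batch, in_f, n_bases)           # B-spline bases per edge
spline_weight = torch.randn(out_f, in_f, n_bases)     # coefficients c_i per edge

# Explicit sum over input features and basis functions
explicit = torch.einsum('bik,oik->bo', splines, spline_weight)

# Flattened single matmul, as in the forward pass above
flat = F.linear(splines.reshape(batch, -1), spline_weight.reshape(out_f, -1))
```

Both paths produce the same `[batch, out_features]` tensor; the flattened version just lets cuBLAS do all the summation in one GEMM.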

4. Common Gotchas

Three failure modes to watch for:

  1. Grid Bound Violations (Dead Gradients): B-splines are only active within their knot range (roughly $-1$ to $1$ here). Inputs outside this range produce zero spline output and kill the gradient.
    Fix: Apply LayerNorm before KAN layers, or implement dynamic grid updates that track activation statistics.
  2. Exploding Spline Variance: High-variance initialization on spline coefficients $c_i$ creates volatile curves and immediate NaN losses.
    Fix: Initialize spline weights with small standard deviation ($\sigma = 0.1$) so the stable SiLU base dominates early training.
  3. Division by Zero in Cox-de Boor: If two knots collapse to the same value (especially during dynamic grid updates), the recursion denominators hit zero.
    Fix: Enforce strictly monotonic grids and guard divisions.
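For gotcha 3, one guarded-division sketch follows the standard Cox-de Boor convention that a $0/0$ term arising from coincident knots is treated as zero. The helper name `safe_div` is an assumption, not part of the repository:

```python
import torch

def safe_div(num, den, eps=1e-8):
    # Cox-de Boor convention: when two knots coincide, the denominator is
    # zero and the entire term is defined to be zero (not inf/NaN)
    zero_mask = den.abs() < eps
    den = torch.where(zero_mask, torch.ones_like(den), den)
    return torch.where(zero_mask, torch.zeros_like(num * den), num / den)
```

Routing every `num / den` in the recursion through a guard like this makes dynamic grid updates safe even if an update momentarily collapses two knots.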

The full implementation comes out to under 100 lines, is GPU-compatible, and exposes grid size and spline order as tunable hyperparameters.
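For reference, here is one way the pieces above assemble into a self-contained layer. Initialization scales, hyperparameter defaults, and the grid construction are assumptions for this sketch rather than the exact values from the repository:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KANLayer(nn.Module):
    """Minimal end-to-end sketch; init choices are assumptions."""
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.spline_order = spline_order
        # 1. Fixed knot grid: uniform over [-1, 1], extended on both sides
        h = 2.0 / grid_size
        grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0
        self.register_buffer("grid", grid)
        # 2. Base weights applied to SiLU(x)
        self.base_weight = nn.Parameter(torch.empty(out_features, in_features))
        # 3. Spline coefficients: one curve per (out, in) edge
        self.spline_weight = nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order))
        self.scale_base, self.scale_spline = 1.0, 1.0
        nn.init.kaiming_uniform_(self.base_weight, a=5 ** 0.5)
        # Small std so the SiLU base dominates early training (gotcha 2)
        nn.init.normal_(self.spline_weight, std=0.1)

    def b_spline(self, x):
        x = x.unsqueeze(-1)  # [batch, in_features, 1]
        bases = ((x >= self.grid[:-1]) & (x < self.grid[1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            left = (x - self.grid[:-k-1]) / (self.grid[k:-1] - self.grid[:-k-1])
            right = (self.grid[k+1:] - x) / (self.grid[k+1:] - self.grid[1:-k])
            bases = left * bases[..., :-1] + right * bases[..., 1:]
        return bases  # [batch, in_features, n_bases]

    def forward(self, x):
        base_output = F.linear(F.silu(x), self.base_weight)
        splines = self.b_spline(x).view(x.shape[0], -1)
        spline_output = F.linear(splines, self.spline_weight.view(self.out_features, -1))
        return self.scale_base * base_output + self.scale_spline * spline_output
```

A shape and gradient smoke test (keeping inputs inside the grid range, per gotcha 1) is enough to confirm the layer is a drop-in for nn.Linear.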