Part 1 covered the math: the Kolmogorov-Arnold theorem and B-spline parameterization. Now we
translate those formulas into a working PyTorch module. The goal is a single
KANLayer module that works as a drop-in replacement for nn.Linear.
1. The Architecture of a KAN Layer
A standard nn.Linear layer computes $y = Wx + b$, with a fixed activation $\sigma$ applied afterward. A
KANLayer instead computes $y_j = \sum_i \phi_{j,i}(x_i)$, where each edge function combines a base
activation with a learned spline:
$$\phi(x) = w_b\,\mathrm{silu}(x) + w_s \sum_i c_i B_i(x)$$
The PyTorch module manages three components:
- The Grid: Fixed knots defining B-spline evaluation points.
- Base Weights: $w_b$, a linear layer applied to SiLU-activated input.
- Spline Weights: $c_i$, coefficients controlling the learned curve shape on each edge.
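Part 1 did not pin down a constructor, so here is a minimal sketch of how these three components might be registered, assuming a uniform grid on $[-1, 1]$ extended by spline_order knots on each side. The hyperparameter names grid_size and spline_order mirror the snippets in this post; the initialization details are one reasonable choice, not the only one.

```python
import torch
import torch.nn as nn

class KANLayer(nn.Module):
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.spline_order = spline_order
        # Uniform knots covering [-1, 1], extended by spline_order on each
        # side so degree-`spline_order` bases are defined across the domain.
        h = 2.0 / grid_size
        grid = torch.arange(-spline_order, grid_size + spline_order + 1,
                            dtype=torch.float32) * h - 1.0
        self.register_buffer("grid", grid)  # fixed knots, not trained
        n_bases = grid_size + spline_order  # bases left after the recursion
        self.base_weight = nn.Parameter(torch.empty(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.empty(out_features, in_features, n_bases))
        nn.init.kaiming_uniform_(self.base_weight, a=5 ** 0.5)
        nn.init.normal_(self.spline_weight, std=0.1)  # small std: see gotcha #2
        self.scale_base = 1.0
        self.scale_spline = 1.0
```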
2. Implementing B-Splines: The Cox-de Boor Algorithm
The hard part is evaluating B-spline basis functions $B_i(x)$ for a full batch without Python
for loops over batch or feature dimensions.
Cox-de Boor recursion starts from degree-0 indicator functions (1 inside a grid interval, 0 outside) and builds higher-degree splines by linearly interpolating between lower-degree ones.
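Written out, the recursion the paragraph describes is:

$$B_{i,0}(x) = \begin{cases} 1 & t_i \le x < t_{i+1} \\ 0 & \text{otherwise} \end{cases}
\qquad
B_{i,k}(x) = \frac{x - t_i}{t_{i+k} - t_i}\, B_{i,k-1}(x) + \frac{t_{i+k+1} - x}{t_{i+k+1} - t_{i+1}}\, B_{i+1,k-1}(x)$$

The first term blends in from the interval starting at $t_i$ and the second blends out toward $t_{i+k+1}$, which is exactly the left/right split in the code.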
Vectorized Code
def b_spline(self, x):
    # Add dimension for broadcasting: x becomes [batch, in_features, 1]
    x = x.unsqueeze(-1)
    # Degree 0: indicator functions
    bases = ((x >= self.grid[:-1]) & (x < self.grid[1:])).to(x.dtype)
    # Cox-de Boor recursion for higher degrees
    for k in range(1, self.spline_order + 1):
        left_num = x - self.grid[:-k-1]
        left_den = self.grid[k:-1] - self.grid[:-k-1]
        left = (left_num / left_den) * bases[:, :, :-1]
        right_num = self.grid[k+1:] - x
        right_den = self.grid[k+1:] - self.grid[1:-k]
        right = (right_num / right_den) * bases[:, :, 1:]
        bases = left + right
    return bases
Broadcasting handles the entire batch and all input features in one pass.
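A quick sanity check helps here: inside the knot range, B-spline bases form a partition of unity, so they should sum to 1 at every input. The sketch below restates the method as a standalone function with an assumed uniform grid (the same layout as the constructor hyperparameters grid_size and spline_order).

```python
import torch

def b_spline_bases(x, grid, spline_order):
    # Standalone version of the method above, for a quick sanity check
    x = x.unsqueeze(-1)
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    for k in range(1, spline_order + 1):
        left = (x - grid[:-k-1]) / (grid[k:-1] - grid[:-k-1]) * bases[..., :-1]
        right = (grid[k+1:] - x) / (grid[k+1:] - grid[1:-k]) * bases[..., 1:]
        bases = left + right
    return bases

# Uniform knots extended past [-1, 1] so degree-3 bases cover the domain
spline_order, grid_size = 3, 5
h = 2.0 / grid_size
grid = torch.arange(-spline_order, grid_size + spline_order + 1,
                    dtype=torch.float32) * h - 1.0
x = torch.rand(4, 2) * 2 - 1          # batch of 4, 2 features, in [-1, 1)
bases = b_spline_bases(x, grid, spline_order)
print(bases.shape)                    # torch.Size([4, 2, 8])
print(bases.sum(-1))                  # every entry should be ~1.0
```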
3. The Forward Pass
With the B-spline bases evaluated, the forward pass reduces to two linear combinations:
def forward(self, x):
    # 1. Base path: SiLU activation, then linear
    base_output = F.linear(F.silu(x), self.base_weight)
    # 2. Spline path: evaluate bases, then one flattened matmul
    splines = self.b_spline(x)                      # [batch, in, n_bases]
    splines = splines.view(x.shape[0], -1)          # [batch, in * n_bases]
    spline_weight_flat = self.spline_weight.view(self.out_features, -1)
    spline_output = F.linear(splines, spline_weight_flat)
    # 3. Combine the two paths
    return self.scale_base * base_output + self.scale_spline * spline_output
Flattening the splines and spline weights lets us use F.linear to sum across all edges
in a single matrix multiply.
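To convince yourself the flattening preserves the edge-wise sum, here is a small equivalence check on random tensors (the shapes are illustrative): an explicit einsum over edges matches the flattened matmul.

```python
import torch
import torch.nn.functional as F

batch, in_features, out_features, n_bases = 4, 3, 5, 8
splines = torch.randn(batch, in_features, n_bases)
spline_weight = torch.randn(out_features, in_features, n_bases)

# Explicit sum over every edge (input i, basis b), keeping structure visible
explicit = torch.einsum("pib,oib->po", splines, spline_weight)

# Flattened version: one matmul over all in_features * n_bases columns
flat = F.linear(splines.reshape(batch, -1),
                spline_weight.reshape(out_features, -1))

print(torch.allclose(explicit, flat, atol=1e-4))  # True
```

The two agree because row-major flattening orders the (input, basis) pairs identically in both tensors.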
4. Common Gotchas
Three failure modes to watch for:
- Grid Bound Violations (Dead Gradients): B-splines are only active within their knot range (roughly $-1$ to $1$ here). Inputs outside this range produce zero spline output and kill the gradient. Fix: apply LayerNorm before KAN layers, or implement dynamic grid updates that track activation statistics.
- Exploding Spline Variance: High-variance initialization of the spline coefficients $c_i$ creates volatile curves and immediate NaN losses. Fix: initialize spline weights with a small standard deviation ($\sigma = 0.1$) so the stable SiLU base dominates early training.
- Division by Zero in Cox-de Boor: If two knots collapse to the same value (especially during dynamic grid updates), the recursion denominators hit zero. Fix: enforce strictly monotonic grids and guard divisions.
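For the third fix, one way to guard the divisions is a mask-then-divide sketch that adopts the Cox-de Boor convention for repeated knots: a term whose denominator is zero is itself zero. The eps threshold here is an assumed tolerance, not a value from the implementation above.

```python
import torch

def guarded_term(num, den, eps=1e-8):
    # Replace (near-)zero denominators with 1 to avoid inf/NaN, then zero
    # out those terms, matching the 0/0 := 0 convention for repeated knots.
    safe_den = torch.where(den.abs() < eps, torch.ones_like(den), den)
    ratio = num / safe_den
    return torch.where(den.abs() < eps, torch.zeros_like(ratio), ratio)

num = torch.tensor([0.5, 3.0, 1.0])
den = torch.tensor([2.0, 0.0, 4.0])
print(guarded_term(num, den))  # tensor([0.2500, 0.0000, 0.2500])
```

Masking before the division matters: dividing first and masking afterward still materializes inf values, which poison gradients even when later multiplied by zero.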
The full implementation comes out to under 100 lines, is GPU-compatible, and exposes grid size and spline order as tunable hyperparameters.