In Part 1 of this series, we explored the elegant mathematical foundation of Kolmogorov-Arnold Networks (KANs). We discussed how the theorem allows us to represent complex multivariate functions using sums of 1D functions, and how B-splines provide the perfect differentiable basis to learn these functions.
Theory is powerful, but implementation represents true understanding. In Part 2, we will take those
continuous mathematical formulations and translate them into a discrete, highly optimized PyTorch
module. Our goal is to write a single KANLayer that acts as a drop-in replacement for
PyTorch's standard nn.Linear.
1. The Architecture of a KAN Layer
A standard Linear layer computes $y = Wx + b$ and then applies a trailing activation $y = \sigma(y)$. Our
KANLayer instead computes $y_j = \sum_i \phi_{j,i}(x_i)$: every edge from input $i$ to output $j$ carries its own learnable 1D function.
Following the original KAN paper, the edge function $\phi(x)$ combines a base activation with a spline activation:

$$\phi(x) = w_b \, \mathrm{silu}(x) + w_s \sum_i c_i B_i(x)$$
Our PyTorch module needs to manage three distinct components:
- The Grid: A set of fixed knots that define where our B-splines are evaluated.
- Base Weights: $w_b$, which act similarly to a standard linear layer applied to a SiLU activation.
- Spline Weights: $c_i$, the coefficients that determine the shape of the learned curve on each edge.
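A minimal constructor tying these three components together might look like the following sketch. The parameter names (`grid_size`, `spline_order`) and the grid-extension scheme are our assumptions for illustration; the exact details may differ in a full implementation:

```python
import torch
import torch.nn as nn

class KANLayer(nn.Module):
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.spline_order = spline_order
        # Grid: fixed knots covering [-1, 1], extended by spline_order knots
        # on each side so degree-k splines are well defined at the boundaries.
        h = 2.0 / grid_size
        grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0
        self.register_buffer("grid", grid)
        # Base weights: a plain linear map applied to silu(x).
        self.base_weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.base_weight, a=5 ** 0.5)
        # Spline weights: one coefficient per (output, input, basis function),
        # initialized small so the base branch dominates early training.
        n_bases = grid_size + spline_order
        self.spline_weight = nn.Parameter(
            torch.randn(out_features, in_features, n_bases) * 0.1
        )
        self.scale_base = 1.0
        self.scale_spline = 1.0
```

Note that the grid is registered as a buffer, not a parameter: the knots are fixed during backpropagation, and only the coefficients on top of them are learned.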
2. Implementing B-Splines: The Cox-de Boor Algorithm
The most complex part of the implementation is evaluating the B-spline basis functions $B_i(x)$ for a
batch of inputs natively in PyTorch. We want to avoid Python for loops over the batch or
feature dimensions to ensure GPU acceleration.
We use the Cox-de Boor recursion formula. We define degree 0 splines as simple indicator functions (1 if the input is within a specific grid interval, 0 otherwise). Higher-degree splines are built by linearly interpolating between lower-degree splines.
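Written out over a knot vector $t_0 < t_1 < \dots$, the recursion is:

```latex
B_{i,0}(x) =
  \begin{cases}
    1 & t_i \le x < t_{i+1} \\
    0 & \text{otherwise}
  \end{cases}
\qquad
B_{i,k}(x) =
  \frac{x - t_i}{t_{i+k} - t_i}\, B_{i,k-1}(x)
  + \frac{t_{i+k+1} - x}{t_{i+k+1} - t_{i+1}}\, B_{i+1,k-1}(x)
```

Each degree-$k$ basis is a convex-like blend of two neighboring degree-$(k-1)$ bases, which is exactly the "left" and "right" structure that appears in the code.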
Vectorized Code Snippet
Here is the core logical block within our module that handles the recursion:
```python
def b_spline(self, x):
    # Add a trailing dimension for broadcasting against the grid:
    # x becomes [batch, in_features, 1].
    x = x.unsqueeze(-1)

    # Degree 0: indicator functions, one per grid interval.
    bases = ((x >= self.grid[:-1]) & (x < self.grid[1:])).to(x.dtype)

    # Cox-de Boor recursion for higher degrees.
    for k in range(1, self.spline_order + 1):
        left_num = x - self.grid[:-k-1]
        left_den = self.grid[k:-1] - self.grid[:-k-1]
        left = (left_num / left_den) * bases[:, :, :-1]

        right_num = self.grid[k+1:] - x
        right_den = self.grid[k+1:] - self.grid[1:-k]
        right = (right_num / right_den) * bases[:, :, 1:]

        bases = left + right
    return bases
```
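A useful sanity check: inside the grid's interior, B-spline bases of any degree sum to 1 (a partition of unity). The snippet below is a free-function version of the method above, written purely so the check can run standalone:

```python
import torch

def b_spline(x, grid, spline_order):
    # Free-function version of KANLayer.b_spline for quick experiments.
    x = x.unsqueeze(-1)  # [batch, in_features, 1]
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    for k in range(1, spline_order + 1):
        left = (x - grid[:-k-1]) / (grid[k:-1] - grid[:-k-1]) * bases[:, :, :-1]
        right = (grid[k+1:] - x) / (grid[k+1:] - grid[1:-k]) * bases[:, :, 1:]
        bases = left + right
    return bases

# Grid covering [-1, 1] with 5 intervals, extended by spline_order knots per side.
spline_order, grid_size = 3, 5
h = 2.0 / grid_size
grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0

x = torch.rand(16, 4) * 2 - 1            # inputs inside [-1, 1]
bases = b_spline(x, grid, spline_order)  # [16, 4, grid_size + spline_order]
print(torch.allclose(bases.sum(-1), torch.ones(16, 4), atol=1e-5))  # True
```

If the inputs drift outside $[-1, 1]$, this sum drops below 1 and eventually hits exactly 0, which is the "dead gradient" failure mode discussed later.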
By utilizing PyTorch's broadcasting capabilities, we process the entire batch and all input features simultaneously.
3. The Forward Pass
Once we have our evaluated B-spline bases, the forward pass is surprisingly simple. It reduces to a large linear combination. We compute the base SiLU output, then compute the spline output by multiplying the evaluated bases by our learnable spline coefficients.
```python
def forward(self, x):
    # 1. Base activation: a plain linear map over silu(x).
    base_output = F.linear(F.silu(x), self.base_weight)

    # 2. Spline activation: evaluate the bases, then take a learned
    #    linear combination of them on every edge.
    splines = self.b_spline(x)                        # [batch, in, n_bases]
    splines = splines.view(x.shape[0], -1)            # [batch, in * n_bases]
    spline_weight_flat = self.spline_weight.view(self.out_features, -1)
    spline_output = F.linear(splines, spline_weight_flat)

    # 3. Combine both branches.
    return self.scale_base * base_output + self.scale_spline * spline_output
```
Notice how we flatten the splines and the spline weights. This allows us to use PyTorch's highly
optimized F.linear to perform the summation across all edges in a single matrix
multiplication, making the execution speed comparable to standard deep learning operations.
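The flattening trick is easy to verify in isolation: a single `F.linear` over the flattened tensors is numerically identical to the explicit per-edge contraction over input features and basis functions (written here with `einsum`):

```python
import torch
import torch.nn.functional as F

batch, in_features, out_features, n_bases = 4, 3, 5, 8
splines = torch.randn(batch, in_features, n_bases)
spline_weight = torch.randn(out_features, in_features, n_bases)

# Explicit contraction: sum over input features i and basis functions k.
out_einsum = torch.einsum("bik,oik->bo", splines, spline_weight)

# Flattened version: one big matrix multiplication.
out_linear = F.linear(
    splines.view(batch, -1), spline_weight.view(out_features, -1)
)

print(torch.allclose(out_einsum, out_linear, atol=1e-5))  # True
```

Both views flatten the `(in_features, n_bases)` axes in the same row-major order, so the element-wise products pair up identically.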
4. Common Implementation Errors & Gotchas
While this clean 100-line implementation is functional, it is important to understand the common failure modes that arise when training KANs in practice:
- Grid Bound Violations (Dead Gradients): Standard B-splines are only active within their defined knot vector. In our implementation, the grid spans roughly $-1$ to $1$. If an input $x$ falls outside this range, every basis function evaluates to exactly zero, which immediately "kills" the gradient for that edge.
  Solution: Strictly normalize inputs (e.g., with LayerNorm) before passing them into a KANLayer, or implement a dynamic grid update that expands the bounds based on activation statistics.
- Exploding Spline Variance: The spline coefficients $c_i$ act like the $y$-values of control points. If they are initialized with high variance, the resulting curves are extremely volatile, causing immediate NaN losses.
  Solution: Initialize the spline weights with a very small standard deviation (e.g., $\sigma = 0.1$) so the early learning phase is dominated by the stable SiLU base activation.
- Division by Zero in Cox-de Boor: In scenarios involving dynamic grid updates, if two knot values become identical, a denominator in the Cox-de Boor recursion becomes zero, producing NaN values.
  Solution: Ensure the grid is strictly monotonically increasing and guard the division so that $0/0$ terms are treated as zero.
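A minimal sketch of that division guard, using `torch.where` (this particular helper is our illustration, not necessarily how a reference implementation handles it):

```python
import torch

def safe_div(num, den, eps=1e-8):
    # Treat 0/0 terms in the Cox-de Boor recursion as 0 instead of NaN:
    # where the denominator vanishes (repeated knots), the corresponding
    # basis function is identically zero there, so 0 is the correct limit.
    return torch.where(den.abs() < eps, torch.zeros_like(num), num / den)

num = torch.tensor([0.5, 0.0, 1.0])
den = torch.tensor([0.0, 2.0, 4.0])
print(safe_div(num, den))  # tensor([0.0000, 0.0000, 0.2500])
```

Replacing the raw `left_num / left_den` and `right_num / right_den` divisions with this helper makes the recursion robust to degenerate grids at negligible cost.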
Conclusion
We have successfully built a mathematically rigorous KAN layer in under 100 lines of code. It is modular, GPU-compatible, and exposes all of the mathematical levers (grid size, spline order) directly to the user.
In Part 3, the grand finale, we will stack these KANLayers together and benchmark them
against a traditional MLP on a complex symbolic regression task to see if the parameter-efficiency
claims hold true.