
Deconstructing Kolmogorov-Arnold Networks (KANs)

Part 2: Pure PyTorch Implementation

Introduction

In Part 1 of this series, we explored the elegant mathematical foundation of Kolmogorov-Arnold Networks (KANs). We discussed how the theorem allows us to represent complex multivariate functions using sums of 1D functions, and how B-splines provide the perfect differentiable basis to learn these functions.

Theory is powerful, but implementation represents true understanding. In Part 2, we will take those continuous mathematical formulations and translate them into a discrete, highly optimized PyTorch module. Our goal is to write a single KANLayer that acts as a drop-in replacement for PyTorch's standard nn.Linear.

View the complete source code on GitHub

1. The Architecture of a KAN Layer

A standard Linear layer computes $y = Wx + b$ and then applies a fixed activation $\sigma$ to the result. Our KANLayer instead places a learnable 1D function on every edge: each output is a sum of learned functions of the individual inputs, $y_j = \sum_i \phi_{i,j}(x_i)$.

Following the original KAN paper, the edge function $\phi(x)$ is a combination of a base activation and a spline activation:

$$ \phi(x) = w_b \text{SiLU}(x) + \sum_{i} c_i B_i(x) $$

Our PyTorch module needs to manage three distinct components:

  1. The Grid: A set of fixed knots that define where our B-splines are evaluated.
  2. Base Weights: $w_b$, which act similarly to a standard linear layer applied to a SiLU activation.
  3. Spline Weights: $c_i$, the coefficients that determine the shape of the learned curve on each edge.
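To make these three components concrete, here is a minimal sketch of how the constructor might register them. The attribute names (`grid`, `base_weight`, `spline_weight`, `scale_base`, `scale_spline`) match the snippets below, but the exact initialization scheme is an assumption, not the definitive implementation:

```python
import torch
import torch.nn as nn

class KANLayer(nn.Module):
    # Illustrative constructor; initialization details are assumptions.
    def __init__(self, in_features, out_features, grid_size=5, spline_order=3):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.spline_order = spline_order

        # 1. The grid: fixed knots spanning [-1, 1], extended by `spline_order`
        #    knots on each side so splines are well-defined at the boundaries.
        h = 2.0 / grid_size
        grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0
        self.register_buffer("grid", grid)  # not a parameter: knots are fixed

        # 2. Base weights w_b: a standard linear map applied to SiLU(x).
        self.base_weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.base_weight, a=5 ** 0.5)

        # 3. Spline weights c_i: one coefficient vector per (input, output) edge.
        #    Small init keeps early training dominated by the base activation.
        self.spline_weight = nn.Parameter(
            0.1 * torch.randn(out_features, in_features, grid_size + spline_order)
        )

        # Learnable mixing scales for the two branches (assumed scalar here).
        self.scale_base = nn.Parameter(torch.ones(1))
        self.scale_spline = nn.Parameter(torch.ones(1))
```

Note the buffer/parameter split: the grid is registered as a buffer so it moves with the module to the GPU but receives no gradients, while the base and spline weights are trainable parameters.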

2. Implementing B-Splines: The Cox-de Boor Algorithm

The most complex part of the implementation is evaluating the B-spline basis functions $B_i(x)$ for a batch of inputs natively in PyTorch. We want to avoid Python for loops over the batch or feature dimensions to ensure GPU acceleration.

We use the Cox-de Boor recursion formula. We define degree 0 splines as simple indicator functions (1 if the input is within a specific grid interval, 0 otherwise). Higher-degree splines are built by linearly interpolating between lower-degree splines.

Vectorized Code Snippet

Here is the core logical block within our module that handles the recursion:

def b_spline(self, x):
    # x: [batch, in_features] -> add a trailing dim for broadcasting
    # against the 1D grid: x becomes [batch, in_features, 1]
    x = x.unsqueeze(-1)
    
    # Degree 0: indicator functions, 1 if x lies in a given grid interval
    bases = ((x >= self.grid[:-1]) & (x < self.grid[1:])).to(x.dtype)
    
    # Cox-de Boor recursion: each pass raises the degree by one and shrinks
    # the number of bases by one. The loop runs over the spline order only
    # (e.g., 3 iterations), never over the batch or features.
    for k in range(1, self.spline_order + 1):
        left_num = x - self.grid[:-k-1]
        left_den = self.grid[k:-1] - self.grid[:-k-1]
        left = (left_num / left_den) * bases[:, :, :-1]
        
        right_num = self.grid[k+1:] - x
        right_den = self.grid[k+1:] - self.grid[1:-k]
        right = (right_num / right_den) * bases[:, :, 1:]
        
        bases = left + right
    
    # Result: [batch, in_features, grid_size + spline_order]
    return bases

By utilizing PyTorch's broadcasting capabilities, we process the entire batch and all input features simultaneously.
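A quick way to sanity-check the recursion is the partition-of-unity property: for any input inside the grid's interior range, the final bases must sum to exactly 1. Below is a standalone version of the same routine, with the grid and order passed explicitly (an assumption about the module's internals) so it can be tested in isolation:

```python
import torch

def b_spline_bases(x, grid, spline_order):
    # x: [batch, in_features]; grid: 1D tensor of strictly increasing knots.
    x = x.unsqueeze(-1)
    # Degree 0: indicator functions over the grid intervals
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    # Cox-de Boor recursion up to the requested order
    for k in range(1, spline_order + 1):
        left = (x - grid[:-k-1]) / (grid[k:-1] - grid[:-k-1]) * bases[:, :, :-1]
        right = (grid[k+1:] - x) / (grid[k+1:] - grid[1:-k]) * bases[:, :, 1:]
        bases = left + right
    return bases

# Grid construction mirrors the layer: 5 intervals on [-1, 1], extended
# by `spline_order` knots on each side.
spline_order, grid_size = 3, 5
h = 2.0 / grid_size
grid = torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0

x = torch.rand(8, 4) * 2 - 1          # inputs inside [-1, 1)
bases = b_spline_bases(x, grid, spline_order)
print(bases.shape)                    # torch.Size([8, 4, 8])
print(bases.sum(-1))                  # ~1.0 everywhere (partition of unity)
```

If the sums drift away from 1, the slicing in the recursion is off by one somewhere; this check catches most indexing mistakes immediately.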

3. The Forward Pass

Once we have our evaluated B-spline bases, the forward pass is surprisingly simple. It reduces to a large linear combination. We compute the base SiLU output, then compute the spline output by multiplying the evaluated bases by our learnable spline coefficients.

def forward(self, x):
    # 1. Base activation: w_b applied to SiLU(x) -> [batch, out_features]
    base_output = F.linear(F.silu(x), self.base_weight)
    
    # 2. Spline activation: evaluate bases, then contract with coefficients
    splines = self.b_spline(x)                  # [batch, in_features, n_bases]
    splines = splines.view(x.shape[0], -1)      # [batch, in_features * n_bases]
    spline_weight_flat = self.spline_weight.view(self.out_features, -1)
    spline_output = F.linear(splines, spline_weight_flat)
    
    # 3. Combine the two branches
    return self.scale_base * base_output + self.scale_spline * spline_output

Notice how we flatten the splines and the spline weights. This allows us to use PyTorch's highly optimized F.linear to perform the summation across all edges in a single matrix multiplication, making the execution speed comparable to standard deep learning operations.
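The flatten-then-`F.linear` trick is equivalent to an explicit contraction over the edge index and the basis index. Writing it as an `einsum` makes the summation structure visible; the shapes below are the ones assumed throughout this post:

```python
import torch
import torch.nn.functional as F

batch, in_f, out_f, n_bases = 32, 4, 3, 8
splines = torch.randn(batch, in_f, n_bases)         # evaluated B-spline bases
spline_weight = torch.randn(out_f, in_f, n_bases)   # learnable coefficients

# Flatten + matmul, as in the forward pass above:
flat = F.linear(splines.reshape(batch, -1), spline_weight.reshape(out_f, -1))

# The same contraction, summing explicitly over edges (i) and bases (k):
explicit = torch.einsum("bik,oik->bo", splines, spline_weight)

print(torch.allclose(flat, explicit, atol=1e-5))    # True
```

Both forms are mathematically identical; the flattened version is used in the forward pass because a single large matrix multiplication tends to map best onto GPU kernels.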

4. Common Implementation Errors & Gotchas

While this clean 100-line implementation is functional, it is important to understand the common failure modes and errors that can arise when training KANs in practice:

  1. Grid Bound Violations (Dead Gradients): Standard B-splines are only active within their defined knot vector. In our implementation, the grid spans from roughly $-1$ to $1$. If the inputs $x$ to the layer exceed this range, the spline evaluation returns exactly zero. This will immediately "kill" the gradient for that edge.
    Solution: Strictly normalize inputs (e.g., LayerNorm) before passing them into a KAN Layer, or implement a dynamic grid update that expands bounds based on activation statistics.
  2. Exploding Spline Variance: The spline coefficients $c_i$ represent the $y$-values of the control points. If these are initialized with high variance, the resulting curves will be extremely volatile, causing immediate NaN losses.
    Solution: Initialize the spline weights with a very small standard deviation (e.g., $\sigma = 0.1$) to ensure the initial learning phase is dominated by the stable base SiLU activation.
  3. Division by Zero in Cox-de Boor: In scenarios involving dynamic grid updates, if two knot values become identical, the denominator in the Cox-de Boor recursion becomes zero, resulting in NaN values.
    Solution: Ensure the grid knots remain strictly increasing after every update, and guard the division itself (e.g., treat $0/0$ terms as $0$, the conventional B-spline limit for repeated knots).
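Gotcha 3 can be guarded against directly in the recursion. Below is one possible "safe divide" helper, a sketch under the assumption that a zero-denominator term should contribute 0 (the conventional B-spline limit); the function name is illustrative, not part of any library:

```python
import torch

def safe_divide(num, den, eps=1e-8):
    # Replace near-zero denominators by 1 before dividing, then zero out
    # those terms. Substituting before the division (rather than masking
    # afterwards) also keeps NaN/Inf out of the backward pass.
    mask = den.abs() > eps
    den_safe = torch.where(mask, den, torch.ones_like(den))
    return torch.where(mask, num / den_safe, torch.zeros_like(num))

# Inside the Cox-de Boor loop, each ratio would then become e.g.:
#   left  = safe_divide(left_num, left_den) * bases[:, :, :-1]
#   right = safe_divide(right_num, right_den) * bases[:, :, 1:]

num = torch.tensor([1.0, 2.0, 3.0])
den = torch.tensor([2.0, 0.0, 3.0])
print(safe_divide(num, den))  # tensor([0.5000, 0.0000, 1.0000])
```

With a fixed, strictly increasing grid this guard is never triggered; it only matters once dynamic grid updates can produce coincident knots.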

Conclusion

We have successfully built a mathematically rigorous KAN layer in under 100 lines of code. It is fully modular, GPU-compatible, and exposes the key mathematical levers (grid size, spline order) directly to the user.

In Part 3, the grand finale, we will stack these KANLayers together and benchmark them against a traditional MLP on a complex symbolic regression task to see if the parameter-efficiency claims hold true.