From 683163406747d56a832e84724d0392b72dfd9b04 Mon Sep 17 00:00:00 2001 From: ajacoby9 Date: Thu, 15 Jan 2026 04:51:37 -0500 Subject: [PATCH 1/3] KAN implementation (#611) * Improve spline * Add KAN --------- Co-authored-by: Filippo Olivo --- .../kolmogorov_arnold_network/kan_layer.py | 223 ++++++++++++++++++ .../kolmogorov_arnold_network/kan_network.py | 194 +++++++++++++++ pina/model/spline.py | 2 +- 3 files changed, 418 insertions(+), 1 deletion(-) create mode 100644 pina/model/kolmogorov_arnold_network/kan_layer.py create mode 100644 pina/model/kolmogorov_arnold_network/kan_network.py diff --git a/pina/model/kolmogorov_arnold_network/kan_layer.py b/pina/model/kolmogorov_arnold_network/kan_layer.py new file mode 100644 index 000000000..ddd360587 --- /dev/null +++ b/pina/model/kolmogorov_arnold_network/kan_layer.py @@ -0,0 +1,223 @@ +"""Create the infrastructure for a KAN layer""" +import torch +import numpy as np + +from pina.model.spline import Spline + + +class KAN_layer(torch.nn.Module): + """define a KAN layer using splines""" + def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_nodes: int, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True) -> None: + """ + Initialize the KAN layer. + """ + super().__init__() + self.k = k + self.input_dimensions = input_dimensions + self.output_dimensions = output_dimensions + self.inner_nodes = inner_nodes + self.num = num + self.grid_eps = grid_eps + self.grid_range = grid_range + self.grid_extension = grid_extension + + if sparse_init: + self.mask = torch.nn.Parameter(self.sparse_mask(input_dimensions, output_dimensions)).requires_grad_(False) + else: + self.mask = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions)).requires_grad_(False) + + grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None,:].expand(self.input_dimensions, self.num+1) + + if grid_extension: + h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) + for i in range(self.k): + grid = torch.cat([grid[:, [0]] - h, grid], dim=1) + grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) + + n_coef = grid.shape[1] - (self.k + 1) + + control_points = torch.nn.Parameter( + torch.randn(self.input_dimensions, self.output_dimensions, n_coef) * noise_scale + ) + + self.spline = Spline(order=self.k+1, knots=grid, control_points=control_points, grid_extension=grid_extension) + + self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ + scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) + self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable) + self.base_function = base_function + + @staticmethod + def sparse_mask(in_dimensions: int, out_dimensions: int) -> torch.Tensor: + ''' + get sparse mask + ''' + in_coord = torch.arange(in_dimensions) * 1/in_dimensions + 1/(2*in_dimensions) + out_coord = torch.arange(out_dimensions) * 1/out_dimensions + 1/(2*out_dimensions) + + dist_mat = torch.abs(out_coord[:,None] - in_coord[None,:]) + in_nearest = torch.argmin(dist_mat, dim=0) + in_connection = torch.stack([torch.arange(in_dimensions), in_nearest]).permute(1,0) + out_nearest = torch.argmin(dist_mat, dim=1) + out_connection = torch.stack([out_nearest, torch.arange(out_dimensions)]).permute(1,0) + all_connection = torch.cat([in_connection, out_connection], dim=0) + mask = torch.zeros(in_dimensions, out_dimensions) + mask[all_connection[:,0], all_connection[:,1]] = 1. + return mask + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the KAN layer. + Each input goes through: w_base*base(x) + w_spline*spline(x) + Then sum across input dimensions for each output node. + """ + if hasattr(x, 'tensor'): + x_tensor = x.tensor + else: + x_tensor = x + + base = self.base_function(x_tensor) # (batch, input_dimensions) + + basis = self.spline.basis(x_tensor, self.spline.k, self.spline.knots) + spline_out_per_input = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + base_term = self.scale_base[None, :, :] * base[:, :, None] + spline_term = self.scale_spline[None, :, :] * spline_out_per_input + combined = base_term + spline_term + combined = self.mask[None,:,:] * combined + + output = torch.sum(combined, dim=1) # (batch, output_dimensions) + + return output + + def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + """ + Update grid from input samples to better fit data distribution. + Based on PyKAN implementation but with boundary preservation. + """ + # Convert LabelTensor to regular tensor for spline operations + if hasattr(x, 'tensor'): + # This is a LabelTensor, extract the tensor part + x_tensor = x.tensor + else: + x_tensor = x + + with torch.no_grad(): + batch_size = x_tensor.shape[0] + x_sorted = torch.sort(x_tensor, dim=0)[0] # (batch_size, input_dimensions) + + # Get current number of intervals (excluding extensions) + if self.grid_extension: + num_interval = self.spline.knots.shape[1] - 1 - 2*self.k + else: + num_interval = self.spline.knots.shape[1] - 1 + + def get_grid(num_intervals: int): + """PyKAN-style grid creation with boundary preservation""" + ids = [int(batch_size * i / num_intervals) for i in range(num_intervals)] + [-1] + grid_adaptive = x_sorted[ids, :].transpose(0, 1) # (input_dimensions, num_intervals+1) + + original_min = self.grid_range[0] + original_max = self.grid_range[1] + + # Clamp adaptive grid to not shrink beyond original domain + grid_adaptive[:, 0] = torch.min(grid_adaptive[:, 0], + torch.full_like(grid_adaptive[:, 0], original_min)) + grid_adaptive[:, -1] = torch.max(grid_adaptive[:, -1], + torch.full_like(grid_adaptive[:, -1], original_max)) + + margin = 0.0 + h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_intervals + grid_uniform = (grid_adaptive[:, [0]] - margin + + h * torch.arange(num_intervals + 1, device=x_tensor.device, dtype=x_tensor.dtype)[None, :]) + + grid_blended = (self.grid_eps * grid_uniform + + (1 - self.grid_eps) * grid_adaptive) + + return grid_blended + + # Create augmented evaluation points: samples + boundary points + # This ensures we preserve boundary behavior while adapting to sample density + boundary_points = torch.tensor([[self.grid_range[0]], [self.grid_range[1]]], + device=x_tensor.device, dtype=x_tensor.dtype).expand(-1, self.input_dimensions) + + # Combine samples with boundary points for evaluation + x_augmented = torch.cat([x_sorted, boundary_points], dim=0) + x_augmented = torch.sort(x_augmented, dim=0)[0] # Re-sort with boundaries included + + # Evaluate current spline at augmented points (samples + boundaries) + basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Create new grid + new_grid = get_grid(num_interval) + + if mode == 'grid': + # For 'grid' mode, use denser sampling + sample_grid = get_grid(2 * num_interval) + x_augmented = sample_grid.transpose(0, 1) # (batch_size, input_dimensions) + basis = self.spline.basis(x_augmented, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Add grid extensions if needed + if self.grid_extension: + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + for i in range(self.k): + new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) + new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) + + # Update grid and refit coefficients + self.spline.knots = new_grid + + try: + # Refit coefficients using augmented points (preserves boundaries) + self.spline.compute_control_points(x_augmented, y_eval) + except Exception as e: + print(f"Warning: Failed to update coefficients during grid refinement: {e}") + + def update_grid_resolution(self, new_num: int): + """ + Update grid resolution to a new number of intervals. + """ + with torch.no_grad(): + # Sample the current spline function on a dense grid + x_eval = torch.linspace( + self.grid_range[0], + self.grid_range[1], + steps=2 * new_num, + device=self.spline.knots.device + ) + x_eval = x_eval.unsqueeze(1).expand(-1, self.input_dimensions) + + basis = self.spline.basis(x_eval, self.spline.k, self.spline.knots) + y_eval = torch.einsum("bil,iol->bio", basis, self.spline.control_points) + + # Update num and create a new grid + self.num = new_num + new_grid = torch.linspace( + self.grid_range[0], + self.grid_range[1], + steps=self.num + 1, + device=self.spline.knots.device + ) + new_grid = new_grid[None, :].expand(self.input_dimensions, self.num + 1) + + if self.grid_extension: + h = (new_grid[:, [-1]] - new_grid[:, [0]]) / (new_grid.shape[1] - 1) + for i in range(self.k): + new_grid = torch.cat([new_grid[:, [0]] - h, new_grid], dim=1) + new_grid = torch.cat([new_grid, new_grid[:, [-1]] + h], dim=1) + + # Update spline with the new grid and re-compute control points + self.spline.knots = new_grid + self.spline.compute_control_points(x_eval, y_eval) + + def get_grid_statistics(self): + """Get statistics about the current grid for debugging/analysis""" + return { + 'grid_shape': self.spline.knots.shape, + 'grid_min': self.spline.knots.min().item(), + 'grid_max': self.spline.knots.max().item(), + 'grid_range': (self.spline.knots.max() - self.spline.knots.min()).mean().item(), + 'num_intervals': self.spline.knots.shape[1] - 1 - (2*self.k if self.spline.grid_extension else 0) + } \ No newline at end of file diff --git a/pina/model/kolmogorov_arnold_network/kan_network.py b/pina/model/kolmogorov_arnold_network/kan_network.py new file mode 100644 index 000000000..cd94a5894 --- /dev/null +++ b/pina/model/kolmogorov_arnold_network/kan_network.py @@ -0,0 +1,194 @@ +"""Kolmogorov Arnold Network implementation""" +import torch +import torch.nn as nn +from typing import List + +try: + from .kan_layer import KAN_layer +except ImportError: + from kan_layer import KAN_layer + +class KAN_Network(torch.nn.Module): + """ + Kolmogorov Arnold Network - A neural network using KAN layers instead of traditional MLP layers. + Each layer uses learnable univariate functions (B-splines + base functions) on edges. + """ + + def __init__( + self, + layer_sizes: List[int], + k: int = 3, + num: int = 3, + grid_eps: float = 0.1, + grid_range: List[float] = [-1, 1], + grid_extension: bool = True, + noise_scale: float = 0.1, + base_function = torch.nn.SiLU(), + scale_base_mu: float = 0.0, + scale_base_sigma: float = 1.0, + scale_sp: float = 1.0, + inner_nodes: int = 5, + sparse_init: bool = False, + sp_trainable: bool = True, + sb_trainable: bool = True, + save_act: bool = True + ): + """ + Initialize the KAN network. + + Args: + layer_sizes: List of integers defining the size of each layer [input_dim, hidden1, hidden2, ..., output_dim] + k: Order of the B-spline + num: Number of grid points for B-splines + grid_eps: Epsilon for grid spacing + grid_range: Range for the grid [min, max] + grid_extension: Whether to extend the grid + noise_scale: Scale for initialization noise + base_function: Base activation function (e.g., SiLU) + scale_base_mu: Mean for base function scaling + scale_base_sigma: Std for base function scaling + scale_sp: Scale for spline functions + """ + super().__init__() + + if len(layer_sizes) < 2: + raise ValueError("Need at least input and output dimensions") + + self.layer_sizes = layer_sizes + self.num_layers = len(layer_sizes) - 1 + self.save_act = save_act + + # Create KAN layers + self.kan_layers = nn.ModuleList() + + for i in range(self.num_layers): + layer = KAN_layer( + k=k, + input_dimensions=layer_sizes[i], + output_dimensions=layer_sizes[i+1], + num=num, + grid_eps=grid_eps, + grid_range=grid_range, + grid_extension=grid_extension, + noise_scale=noise_scale, + base_function=base_function, + scale_base_mu=scale_base_mu, + scale_base_sigma=scale_base_sigma, + scale_sp=scale_sp, + inner_nodes=inner_nodes, + sparse_init=sparse_init, + sp_trainable=sp_trainable, + sb_trainable=sb_trainable + ) + self.kan_layers.append(layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the KAN network. + + Args: + x: Input tensor of shape (batch_size, input_dimensions) + + Returns: + Output tensor of shape (batch_size, output_dimensions) + """ + current = x + self.acts = [current] + + for i, layer in enumerate(self.kan_layers): + current = layer(current) + + if self.save_act: + self.acts.append(current.detach()) + + return current + + def get_num_parameters(self) -> int: + """Get total number of trainable parameters""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + + def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): + """ + Update grid for all layers based on input samples. + This adapts the grid points to better fit the data distribution. + + Args: + x: Input samples, shape (batch_size, input_dimensions) + mode: 'sample' or 'grid' - determines sampling strategy + """ + current = x + + for i, layer in enumerate(self.kan_layers): + layer.update_grid_from_samples(current, mode=mode) + + if i < len(self.kan_layers) - 1: + with torch.no_grad(): + current = layer(current) + + def update_grid_resolution(self, new_num: int): + """ + Update the grid resolution for all layers. + This can be used for adaptive training where grid resolution increases over time. + + Args: + new_num: New number of grid points + """ + for layer in self.kan_layers: + layer.update_grid_resolution(new_num) + + def enable_sparsification(self, threshold: float = 1e-4): + """ + Enable sparsification by setting small weights to zero. + + Args: + threshold: Threshold below which weights are set to zero + """ + with torch.no_grad(): + for layer in self.kan_layers: + # Sparsify scale parameters + layer.scale_base.data[torch.abs(layer.scale_base.data) < threshold] = 0 + layer.scale_spline.data[torch.abs(layer.scale_spline.data) < threshold] = 0 + + # Update mask + layer.mask.data = ((torch.abs(layer.scale_base) >= threshold) | + (torch.abs(layer.scale_spline) >= threshold)).float() + + def get_activation_statistics(self, x: torch.Tensor): + """ + Get statistics about activations for analysis purposes. + + Args: + x: Input tensor + + Returns: + Dictionary with activation statistics + """ + stats = {} + current = x + + for i, layer in enumerate(self.kan_layers): + current = layer(current) + stats[f'layer_{i}'] = { + 'mean': current.mean().item(), + 'std': current.std().item(), + 'min': current.min().item(), + 'max': current.max().item() + } + + return stats + + + def get_network_grid_statistics(self): + """ + Get grid statistics for all layers in the network. + + Returns: + Dictionary with grid statistics for each layer + """ + stats = {} + for i, layer in enumerate(self.kan_layers): + stats[f'layer_{i}'] = layer.get_grid_statistics() + return stats + + \ No newline at end of file diff --git a/pina/model/spline.py b/pina/model/spline.py index d9141fe8c..100b2c6d0 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -475,4 +475,4 @@ def knots(self, value): self._boundary_interval_idx = self._compute_boundary_interval() # Recompute derivative denominators when knots change - self._compute_derivative_denominators() + self._compute_derivative_denominators() \ No newline at end of file From 33d3862a0c6dd4131bb360019b45d3b4c6700565 Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Wed, 21 Jan 2026 14:27:14 +0100 Subject: [PATCH 2/3] KAN with non-vectorized spline --- pina/condition/tensor_condition.py | 84 ++++++++++ pina/model/__init__.py | 3 + pina/model/block/__init__.py | 2 + .../kan_layer.py => block/kan_block.py} | 66 +++++--- ...etwork.py => kolmogorov_arnold_network.py} | 54 +++--- pina/model/spline.py | 7 +- .../test_kolmogorov_arnold_network.py | 156 ++++++++++++++++++ 7 files changed, 326 insertions(+), 46 deletions(-) create mode 100644 pina/condition/tensor_condition.py rename pina/model/{kolmogorov_arnold_network/kan_layer.py => block/kan_block.py} (80%) rename pina/model/{kolmogorov_arnold_network/kan_network.py => kolmogorov_arnold_network.py} (76%) create mode 100644 tests/test_model/test_kolmogorov_arnold_network.py diff --git a/pina/condition/tensor_condition.py b/pina/condition/tensor_condition.py new file mode 100644 index 000000000..fa4b53637 --- /dev/null +++ b/pina/condition/tensor_condition.py @@ -0,0 +1,84 @@ +"""Module for the DataCondition class.""" + +import torch +from torch_geometric.data import Data +from .condition_interface import ConditionInterface +from ..label_tensor import LabelTensor +from ..graph import Graph + + +class _TensorCondition(ConditionInterface): + + __slots__ = ["input", "conditional_variables"] + _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) + _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) + + def __new__(cls, input, conditional_variables=None): + """ + Instantiate the appropriate subclass of :class:`DataCondition` based on + the type of ``input``. + + :param input: Input data for the condition. + :type input: torch.Tensor | LabelTensor | Graph | + Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] + :param conditional_variables: Conditional variables for the condition. + :type conditional_variables: torch.Tensor | LabelTensor, optional + :return: Subclass of DataCondition. + :rtype: pina.condition.data_condition.TensorDataCondition | + pina.condition.data_condition.GraphDataCondition + + :raises ValueError: If input is not of type :class:`torch.Tensor`, + :class:`~pina.label_tensor.LabelTensor`, :class:`~pina.graph.Graph`, + or :class:`~torch_geometric.data.Data`. + """ + + if cls != DataCondition: + return super().__new__(cls) + if isinstance(input, (torch.Tensor, LabelTensor)): + subclass = TensorDataCondition + return subclass.__new__(subclass, input, conditional_variables) + + if isinstance(input, (Graph, Data, list, tuple)): + cls._check_graph_list_consistency(input) + subclass = GraphDataCondition + return subclass.__new__(subclass, input, conditional_variables) + + raise ValueError( + "Invalid input types. " + "Please provide either torch_geometric.data.Data or Graph objects." + ) + + def __init__(self, input, conditional_variables=None): + """ + Initialize the object by storing the input and conditional + variables (if any). + + :param input: Input data for the condition. + :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | + list[Data] | tuple[Graph] | tuple[Data] + :param conditional_variables: Conditional variables for the condition. + :type conditional_variables: torch.Tensor | LabelTensor + + .. note:: + If ``input`` consists of a list of :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data`, all elements must have the same + structure (keys and data types) + """ + + super().__init__() + self.input = input + self.conditional_variables = conditional_variables + + +class TensorDataCondition(DataCondition): + """ + DataCondition for :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` input data + """ + + +class GraphDataCondition(DataCondition): + """ + DataCondition for :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data` input data + """ diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 05ccc6c8c..f724e9cf3 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -16,6 +16,8 @@ "PirateNet", "EquivariantGraphNeuralOperator", "SINDy", + "SplineSurface", + "KolmogorovArnoldNetwork", ] from .feed_forward import FeedForward, ResidualFeedForward @@ -31,3 +33,4 @@ from .pirate_network import PirateNet from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator from .sindy import SINDy +from .kolmogorov_arnold_network import KolmogorovArnoldNetwork diff --git a/pina/model/block/__init__.py b/pina/model/block/__init__.py index 08b313387..788c78864 100644 --- a/pina/model/block/__init__.py +++ b/pina/model/block/__init__.py @@ -19,6 +19,7 @@ "RBFBlock", "GNOBlock", "PirateNetBlock", + "KANBlock", ] from .convolution_2d import ContinuousConvBlock @@ -37,3 +38,4 @@ from .rbf_block import RBFBlock from .gno_block import GNOBlock from .pirate_network_block import PirateNetBlock +from .kan_block import KANBlock diff --git a/pina/model/kolmogorov_arnold_network/kan_layer.py b/pina/model/block/kan_block.py similarity index 80% rename from pina/model/kolmogorov_arnold_network/kan_layer.py rename to pina/model/block/kan_block.py index ddd360587..93048c520 100644 --- a/pina/model/kolmogorov_arnold_network/kan_layer.py +++ b/pina/model/block/kan_block.py @@ -5,11 +5,13 @@ from pina.model.spline import Spline -class KAN_layer(torch.nn.Module): +class KANBlock(torch.nn.Module): """define a KAN layer using splines""" def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_nodes: int, num=3, grid_eps=0.1, grid_range=[-1, 1], grid_extension=True, noise_scale=0.1, base_function=torch.nn.SiLU(), scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, sparse_init=True, sp_trainable=True, sb_trainable=True) -> None: """ Initialize the KAN layer. + + num è il numero di intervalli nella griglia iniziale (esclusi gli eventuali nodi di estensione) """ super().__init__() self.k = k @@ -27,6 +29,7 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_ self.mask = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions)).requires_grad_(False) grid = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1)[None,:].expand(self.input_dimensions, self.num+1) + knots = torch.linspace(grid_range[0], grid_range[1], steps=self.num + 1) if grid_extension: h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1) @@ -34,17 +37,38 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_ grid = torch.cat([grid[:, [0]] - h, grid], dim=1) grid = torch.cat([grid, grid[:, [-1]] + h], dim=1) - n_coef = grid.shape[1] - (self.k + 1) + n_control_points = len(knots) - (self.k ) - control_points = torch.nn.Parameter( - torch.randn(self.input_dimensions, self.output_dimensions, n_coef) * noise_scale - ) + # control_points = torch.nn.Parameter( + # torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale + # ) + # print(control_points.shape) + spline_q = [] + for q in range(self.output_dimensions): + spline_p = [] + for p in range(self.input_dimensions): + spline_ = Spline( + order=self.k, + knots=knots, + control_points=torch.randn(n_control_points) + ) + spline_p.append(spline_) + spline_p = torch.nn.ModuleList(spline_p) + spline_q.append(spline_p) + self.spline_q = torch.nn.ModuleList(spline_q) + + + # control_points = torch.nn.Parameter( + # torch.randn(n_control_points, self.output_dimensions) * noise_scale) + # print(control_points) + # print('uuu') - self.spline = Spline(order=self.k+1, knots=grid, control_points=control_points, grid_extension=grid_extension) + # self.spline = Spline( + # order=self.k, knots=knots, control_points=control_points) - self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ - scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) - self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable) + # self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(input_dimensions) + \ + # scale_base_sigma * (torch.rand(input_dimensions, output_dimensions)*2-1) * 1/np.sqrt(input_dimensions), requires_grad=sb_trainable) + # self.scale_spline = torch.nn.Parameter(torch.ones(input_dimensions, output_dimensions) * scale_sp * 1 / np.sqrt(input_dimensions) * self.mask, requires_grad=sp_trainable) self.base_function = base_function @staticmethod @@ -76,19 +100,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: x_tensor = x - base = self.base_function(x_tensor) # (batch, input_dimensions) - - basis = self.spline.basis(x_tensor, self.spline.k, self.spline.knots) - spline_out_per_input = torch.einsum("bil,iol->bio", basis, self.spline.control_points) - - base_term = self.scale_base[None, :, :] * base[:, :, None] - spline_term = self.scale_spline[None, :, :] * spline_out_per_input - combined = base_term + spline_term - combined = self.mask[None,:,:] * combined - - output = torch.sum(combined, dim=1) # (batch, output_dimensions) - - return output + y = [] + for q in range(self.output_dimensions): + y_q = [] + for p in range(self.input_dimensions): + spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions) + base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions) + y_q.append(spline_out + base_out) + y.append(torch.stack(y_q, dim=1).sum(dim=1)) + y = torch.stack(y, dim=1) + + return y def update_grid_from_samples(self, x: torch.Tensor, mode: str = 'sample'): """ diff --git a/pina/model/kolmogorov_arnold_network/kan_network.py b/pina/model/kolmogorov_arnold_network.py similarity index 76% rename from pina/model/kolmogorov_arnold_network/kan_network.py rename to pina/model/kolmogorov_arnold_network.py index cd94a5894..ad518b7bc 100644 --- a/pina/model/kolmogorov_arnold_network/kan_network.py +++ b/pina/model/kolmogorov_arnold_network.py @@ -3,15 +3,20 @@ import torch.nn as nn from typing import List -try: - from .kan_layer import KAN_layer -except ImportError: - from kan_layer import KAN_layer +from pina.model.block import KANBlock -class KAN_Network(torch.nn.Module): +class KolmogorovArnoldNetwork(torch.nn.Module): """ - Kolmogorov Arnold Network - A neural network using KAN layers instead of traditional MLP layers. - Each layer uses learnable univariate functions (B-splines + base functions) on edges. + Kolmogorov Arnold Network, a neural network using KAN layers instead of + traditional MLP layers. Each layer uses learnable univariate functions + (B-splines + base functions) on edges. + + .. references:: + + Liu, Z., Wang, Y., Vaidya, S., Ruehle, F., Halverson, J., Soljačić, M., + ... & Tegmark, M. (2024). Kan: Kolmogorov-arnold networks. arXiv + preprint arXiv:2404.19756. + """ def __init__( @@ -35,19 +40,25 @@ def __init__( ): """ Initialize the KAN network. - - Args: - layer_sizes: List of integers defining the size of each layer [input_dim, hidden1, hidden2, ..., output_dim] - k: Order of the B-spline - num: Number of grid points for B-splines - grid_eps: Epsilon for grid spacing - grid_range: Range for the grid [min, max] - grid_extension: Whether to extend the grid - noise_scale: Scale for initialization noise - base_function: Base activation function (e.g., SiLU) - scale_base_mu: Mean for base function scaling - scale_base_sigma: Std for base function scaling - scale_sp: Scale for spline functions + + :param iterable layer_sizes: List of layer sizes including input and + output dimensions. + :param int k: Order of the B-spline. + :param int num: Number of grid points for B-splines. + :param float grid_eps: Epsilon for grid spacing. + :param list grid_range: Range for the grid [min, max]. + :param bool grid_extension: Whether to extend the grid. + :param float noise_scale: Scale for initialization noise. + :param base_function: Base activation function (e.g., SiLU). + :param float scale_base_mu: Mean for base function scaling. + :param float scale_base_sigma: Std for base function scaling. + :param float scale_sp: Scale for spline functions. + :param int inner_nodes: Number of inner nodes for KAN layers. + :param bool sparse_init: Whether to use sparse initialization. + :param bool sp_trainable: Whether spline parameters are trainable. + :param bool sb_trainable: Whether base function parameters are + trainable. + :param bool save_act: Whether to save activations after each layer. """ super().__init__() @@ -62,7 +73,7 @@ def __init__( self.kan_layers = nn.ModuleList() for i in range(self.num_layers): - layer = KAN_layer( + layer = KANBlock( k=k, input_dimensions=layer_sizes[i], output_dimensions=layer_sizes[i+1], @@ -97,6 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for i, layer in enumerate(self.kan_layers): current = layer(current) + current = torch.nn.functional.sigmoid(current) if self.save_act: self.acts.append(current.detach()) diff --git a/pina/model/spline.py b/pina/model/spline.py index 100b2c6d0..7383e7268 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -117,7 +117,7 @@ def __init__(self, order=4, knots=None, control_points=None): raise ValueError("knots must be one-dimensional.") # Check dimensionality of control points - if self.control_points.ndim > 1: + if self.control_points.ndim > 2: raise ValueError("control_points must be one-dimensional.") # Raise error if #knots != order + #control_points @@ -277,9 +277,10 @@ def forward(self, x): :return: The output tensor. :rtype: torch.Tensor """ + basis = self.basis(x.as_subclass(torch.Tensor)) return torch.einsum( - "...bi, i -> ...b", - self.basis(x.as_subclass(torch.Tensor)).squeeze(-1), + "...bi, ...i -> ...b", + basis, self.control_points, ) diff --git a/tests/test_model/test_kolmogorov_arnold_network.py b/tests/test_model/test_kolmogorov_arnold_network.py new file mode 100644 index 000000000..245257028 --- /dev/null +++ b/tests/test_model/test_kolmogorov_arnold_network.py @@ -0,0 +1,156 @@ +import torch +import pytest + +from pina.model.block import KANBlock +from pina.model import KolmogorovArnoldNetwork + +data = torch.rand((20, 3)) +input_vars = 3 +output_vars = 1 + + +def test_constructor(): + KolmogorovArnoldNetwork([input_vars, output_vars]) + KolmogorovArnoldNetwork([input_vars, 10, 20, output_vars]) + KolmogorovArnoldNetwork( + [input_vars, 10, 20, output_vars], + k=3, + num=5 + ) + KolmogorovArnoldNetwork( + [input_vars, 10, 20, output_vars], + k=3, + num=5, + grid_eps=0.05, + grid_range=[-2, 2] + ) + KolmogorovArnoldNetwork( + [input_vars, 10, output_vars], + base_function=torch.nn.Tanh(), + scale_sp=0.5, + sparse_init=True + ) + + +def test_constructor_wrong(): + with pytest.raises(ValueError): + KolmogorovArnoldNetwork([input_vars]) + with pytest.raises(ValueError): + KolmogorovArnoldNetwork([]) + + +def test_forward(): + dim_in, dim_out = 3, 2 + kan = KolmogorovArnoldNetwork([dim_in, dim_out]) + output_ = kan(data) + assert output_.shape == (data.shape[0], dim_out) + + +def test_forward_multilayer(): + dim_in, dim_out = 3, 2 + kan = KolmogorovArnoldNetwork([dim_in, 10, 5, dim_out]) + output_ = kan(data) + assert output_.shape == (data.shape[0], dim_out) + + +def test_backward(): + dim_in, dim_out = 3, 2 + kan = KolmogorovArnoldNetwork([dim_in, dim_out]) + data.requires_grad = True + output_ = kan(data) + loss = torch.mean(output_) + loss.backward() + assert data._grad.shape == torch.Size([20, 3]) + + +def test_get_num_parameters(): + kan = KolmogorovArnoldNetwork([3, 5, 2]) + num_params = kan.get_num_parameters() + assert num_params > 0 + assert isinstance(num_params, int) + +from pina.problem.zoo import Poisson2DSquareProblem +from pina.solver import PINN +from pina.trainer import Trainer + +def test_train_poisson(): + problem = Poisson2DSquareProblem() + problem.discretise_domain(n=10, mode="random", domains="all") + + model = KolmogorovArnoldNetwork([2, 3, 1], k=3, num=5) + solver = PINN(model=model, problem=problem) + trainer = Trainer( + solver=solver, + max_epochs=1000, + accelerator="cpu", + batch_size=100, + train_size=1.0, + val_size=0.0, + test_size=0.0, + ) + trainer.train() + assert False + + + + +# def test_update_grid_from_samples(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# samples = torch.randn(50, 3) +# kan.update_grid_from_samples(samples, mode='sample') +# # Check that the network still works after grid update +# output = kan(data) +# assert output.shape == (data.shape[0], 2) + + +# def test_update_grid_resolution(): +# kan = KolmogorovArnoldNetwork([3, 5, 2], num=3) +# kan.update_grid_resolution(5) +# # Check that the network still works after resolution update +# output = kan(data) +# assert output.shape == (data.shape[0], 2) + + +# def test_enable_sparsification(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# kan.enable_sparsification(threshold=1e-4) +# # Check that the network still works after sparsification +# output = kan(data) +# assert output.shape == (data.shape[0], 2) + + +# def test_get_activation_statistics(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# stats = kan.get_activation_statistics(data) +# assert isinstance(stats, dict) +# assert 'layer_0' in stats +# assert 'layer_1' in stats +# assert 'mean' in stats['layer_0'] +# assert 'std' in stats['layer_0'] +# assert 'min' in stats['layer_0'] +# assert 'max' in stats['layer_0'] + + +# def test_get_network_grid_statistics(): +# kan = KolmogorovArnoldNetwork([3, 5, 2]) +# stats = kan.get_network_grid_statistics() +# assert isinstance(stats, dict) +# assert 'layer_0' in stats +# assert 'layer_1' in stats + + +# def test_save_act(): +# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=True) +# output = kan(data) +# assert hasattr(kan, 'acts') +# assert len(kan.acts) == 3 # input + 2 layers +# assert kan.acts[0].shape == data.shape +# assert kan.acts[-1].shape == output.shape + + +# def test_save_act_disabled(): +# kan = KolmogorovArnoldNetwork([3, 5, 2], save_act=False) +# _ = kan(data) +# assert hasattr(kan, 'acts') +# # Only the first activation (input) is saved +# assert len(kan.acts) == 1 From 3ad5a68b273f9c26aa923dbf224430db1ea9bd2d Mon Sep 17 00:00:00 2001 From: Nicola Demo Date: Thu, 22 Jan 2026 10:57:24 +0100 Subject: [PATCH 3/3] vectorized spline --- pina/model/block/kan_block.py | 70 +++++++--- pina/model/kolmogorov_arnold_network.py | 2 +- pina/model/spline.py | 169 +++++++++++++++++++++++- tests/test_model/test_spline.py | 38 ++++++ 4 files changed, 254 insertions(+), 25 deletions(-) diff --git a/pina/model/block/kan_block.py b/pina/model/block/kan_block.py index 93048c520..ffb8edc6b 100644 --- a/pina/model/block/kan_block.py +++ b/pina/model/block/kan_block.py @@ -22,6 +22,8 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_ self.grid_eps = grid_eps self.grid_range = grid_range self.grid_extension = grid_extension + self.vec = True + # self.vec = False if sparse_init: self.mask = torch.nn.Parameter(self.sparse_mask(input_dimensions, output_dimensions)).requires_grad_(False) @@ -43,19 +45,35 @@ def __init__(self, k: int, input_dimensions: int, output_dimensions: int, inner_ # torch.randn(self.input_dimensions, self.output_dimensions, n_control_points) * noise_scale # ) # print(control_points.shape) - spline_q = [] - for q in range(self.output_dimensions): - spline_p = [] - for p in range(self.input_dimensions): - spline_ = Spline( - order=self.k, - knots=knots, - control_points=torch.randn(n_control_points) - ) - spline_p.append(spline_) - spline_p = torch.nn.ModuleList(spline_p) - spline_q.append(spline_p) - self.spline_q = torch.nn.ModuleList(spline_q) + if self.vec: + from pina.model.spline import SplineVectorized as VectorizedSpline + control_points = torch.randn(self.input_dimensions * self.output_dimensions, n_control_points) + print('control points', control_points.shape) + control_points = torch.stack([ + torch.randn(n_control_points) + for _ in range(self.input_dimensions * self.output_dimensions) + ]) + print('control points', control_points.shape) + self.spline_q = VectorizedSpline( + order=self.k, + knots=knots, + control_points=control_points + ) + + else: + spline_q = [] + for q in range(self.output_dimensions): + spline_p = [] + for p in range(self.input_dimensions): + spline_ = Spline( + order=self.k, + knots=knots, + control_points=torch.randn(n_control_points) + ) + spline_p.append(spline_) + spline_p = torch.nn.ModuleList(spline_p) + spline_q.append(spline_p) + self.spline_q = torch.nn.ModuleList(spline_q) # control_points = torch.nn.Parameter( @@ -99,16 +117,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_tensor = x.tensor else: x_tensor = x + - y = [] - for q in range(self.output_dimensions): - y_q = [] - for p in range(self.input_dimensions): - spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions) - base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions) - y_q.append(spline_out + base_out) - y.append(torch.stack(y_q, dim=1).sum(dim=1)) - y = torch.stack(y, dim=1) + if self.vec: + y = self.spline_q.forward(x_tensor) # (batch, output_dimensions, input_dimensions) + y = y.reshape(y.shape[0], y.shape[1], self.output_dimensions, self.input_dimensions) + base_out = self.base_function(x_tensor) # (batch, input_dimensions) + y = y + base_out[:, :, None, None] + y = y.sum(dim=3).sum(dim=1) # sum over input dimensions + else: + y = [] + for q in range(self.output_dimensions): + y_q = [] + for p in range(self.input_dimensions): + spline_out = self.spline_q[q][p].forward(x_tensor[:, p]) # (batch, input_dimensions, output_dimensions) + base_out = self.base_function(x_tensor[:, p]) # (batch, input_dimensions) + y_q.append(spline_out + base_out) + y.append(torch.stack(y_q, dim=1).sum(dim=1)) + y = torch.stack(y, dim=1) return y diff --git a/pina/model/kolmogorov_arnold_network.py b/pina/model/kolmogorov_arnold_network.py index ad518b7bc..e798b677a 100644 --- a/pina/model/kolmogorov_arnold_network.py +++ b/pina/model/kolmogorov_arnold_network.py @@ -108,7 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for i, layer in enumerate(self.kan_layers): current = layer(current) - current = torch.nn.functional.sigmoid(current) + # current = torch.nn.functional.sigmoid(current) if self.save_act: self.acts.append(current.detach()) diff --git a/pina/model/spline.py b/pina/model/spline.py index 7383e7268..46946b1f6 100644 --- a/pina/model/spline.py +++ b/pina/model/spline.py @@ -278,6 +278,7 @@ def forward(self, x): :rtype: torch.Tensor """ basis = self.basis(x.as_subclass(torch.Tensor)) + return basis @ self.control_points return torch.einsum( "...bi, ...i -> ...b", basis, @@ -473,7 +474,171 @@ def knots(self, value): # Recompute boundary interval when knots change if hasattr(self, "_boundary_interval_idx"): - self._boundary_interval_idx = self._compute_boundary_interval() + self._boundary_interval_Widx = self._compute_boundary_interval() # Recompute derivative denominators when knots change - self._compute_derivative_denominators() \ No newline at end of file + self._compute_derivative_denominators() + + +import torch +import torch.nn as nn + +class SplineVectorized(nn.Module): + """ + Vectorized univariate B-spline model (shared knots, many splines). + + Notation: + - knots: shape (m,) + - order: k (degree = k-1) + - n_ctrl = m - k + - control_points: + * (S, n_ctrl) -> S splines, scalar output each + * (S, O, n_ctrl) -> S splines, O outputs each (like multiple channels) + Input: + - x: shape (...,) or (..., B) + Output: + - if control_points is (S, n_ctrl): shape (..., S) + - if control_points is (S, O, n_ctrl): shape (..., S, O) + """ + + def __init__(self, order: int, knots: torch.Tensor, control_points: torch.Tensor | None = None): + super().__init__() + if not isinstance(order, int) or order <= 0: + raise ValueError("order must be a positive integer.") + if not isinstance(knots, torch.Tensor): + raise ValueError("knots must be a torch.Tensor.") + if knots.ndim != 1: + raise ValueError("knots must be 1D.") + + self.order = order + + # store sorted knots as buffer + knots_sorted = knots.sort().values + self.register_buffer("knots", knots_sorted) + + n_ctrl = knots_sorted.numel() - order + if n_ctrl <= 0: + raise ValueError(f"Need #knots > order. Got #knots={knots_sorted.numel()} order={order}.") + + # boundary interval idx for rightmost inclusion + self._boundary_interval_idx = self._compute_boundary_interval_idx(knots_sorted) + + # # control points init + # if control_points is None: + # # default: one spline + # cp = torch.zeros(1, n_ctrl, dtype=knots_sorted.dtype, device=knots_sorted.device) + # self.control_points = nn.Parameter(cp, requires_grad=True) + # else: + # if not isinstance(control_points, torch.Tensor): + # raise ValueError("control_points must be a torch.Tensor or None.") + # if control_points.ndim not in (2, 3): + # raise ValueError("control_points must have shape (S, n_ctrl) or (S, O, n_ctrl).") + # if control_points.shape[-1] != n_ctrl: + # raise ValueError( + # f"Last dim of control_points must be n_ctrl={n_ctrl}. Got {control_points.shape[-1]}." + # ) + self.control_points = nn.Parameter(control_points, requires_grad=True) + + @staticmethod + def _compute_boundary_interval_idx(knots: torch.Tensor) -> int: + if knots.numel() < 2: + return 0 + diffs = knots[1:] - knots[:-1] + valid = torch.nonzero(diffs > 0, as_tuple=False) + if valid.numel() == 0: + return 0 + return int(valid[-1]) + + def basis(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute B-spline basis functions of order self.order at x. + + Returns: + basis: shape (..., n_ctrl) + """ + if not isinstance(x, torch.Tensor): + x = torch.as_tensor(x) + + # ensure float dtype consistent + x = x.to(dtype=self.knots.dtype, device=self.knots.device) + + # make x shape (..., 1) for broadcasting + x_exp = x.unsqueeze(-1) # (..., 1) + + # knots as (1, ..., 1, m) via unsqueeze to broadcast + # (m,) -> (1,)*x.ndim + (m,) + knots = self.knots.view(*([1] * x.ndim), -1) + + # order-1 base: indicator on intervals [t_i, t_{i+1}) + basis = ((x_exp >= knots[..., :-1]) & (x_exp < knots[..., 1:])).to(x_exp.dtype) # (..., m-1) + + # include rightmost boundary in the last non-degenerate interval + j = self._boundary_interval_idx + knot_left = knots[..., j] + knot_right = knots[..., j + 1] + at_right = (x >= knot_left.squeeze(-1)) & torch.isclose(x, knot_right.squeeze(-1), rtol=1e-8, atol=1e-10) + if torch.any(at_right): + basis_j = basis[..., j].bool() | at_right + basis[..., j] = basis_j.to(basis.dtype) + + # Cox-de Boor recursion up to order k + # after i-th iteration, basis has length (m-1 - i) + for i in range(1, self.order): + denom1 = knots[..., i:-1] - knots[..., :-(i + 1)] + denom2 = knots[..., i + 1:] - knots[..., 1:-i] + + denom1 = torch.where(denom1.abs() < 1e-8, torch.ones_like(denom1), denom1) + denom2 = torch.where(denom2.abs() < 1e-8, torch.ones_like(denom2), denom2) + + term1 = ((x_exp - knots[..., :-(i + 1)]) / denom1) * basis[..., :-1] + term2 = ((knots[..., i + 1:] - x_exp) / denom2) * basis[..., 1:] + basis = term1 + term2 + + # final basis length is n_ctrl = m - order + return basis # (..., n_ctrl) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Evaluate spline(s) at x. + + If control_points is (S, n_ctrl): output (..., S) + If control_points is (S, O, n_ctrl): output (..., S, O) + """ + B = self.basis(x) # (..., n_ctrl) + + cp = self.control_points + if cp.ndim == 2: + # (S, n_ctrl) + # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) + out = B @ cp.transpose(0, 1) + return out + else: + # (S, O, n_ctrl) + # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S + # vectorized using einsum (yes, this one is actually appropriate) + # (..., n) * (S, O, n) -> (..., S, O) + # out = torch.einsum("...n, son -> ...so", B, cp) + out = torch.einsum("bsc,sco->bso", B, cp) + + return out + + def forward_basis(self, basis): + """ + Evaluate spline(s) given precomputed basis. + + """ + cp = self.control_points + if cp.ndim == 2: + # (S, n_ctrl) + # want (..., S) = (..., n_ctrl) @ (n_ctrl, S) + out = basis @ cp.transpose(0, 1) + return out + else: + # (S, O, n_ctrl) + # Compute for each S: (..., n_ctrl) @ (n_ctrl, O) -> (..., O), then stack over S + # vectorized using einsum (yes, this one is actually appropriate) + # (..., n) * (S, O, n) -> (..., S, O) + # out = torch.einsum("...n, son -> ...so", B, cp) + out = torch.einsum("bsc,sco->bso", basis, cp) + + return out diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py index b47ea8d30..1dbf8ac28 100644 --- a/tests/test_model/test_spline.py +++ b/tests/test_model/test_spline.py @@ -192,3 +192,41 @@ def test_derivative(args, pts): # Check shape and value assert first_der.shape == pts.shape assert torch.allclose(first_der, first_der_auto, atol=1e-4, rtol=1e-4) + + +@pytest.mark.parametrize("out_dim", [1, 3, 5]) +def test_vectorized(out_dim): + + N = 7 + cps = [] + splines = [] + for i in range(N): + cp = torch.rand(n_ctrl_pts, 3) + cps.append(cp) + spline = Spline( + order=order, + control_points=cp + ) + splines.append(spline) + + from pina.model.spline import SplineVectorized as VectorizedSpline + unique_cps = torch.stack(cps, dim=0) + print(unique_cps.shape) + print(cps[0].shape) + # Vectorized control points + vectorized_spline = VectorizedSpline( + order=order, + knots=splines[0].knots, + control_points=torch.stack(cps, dim=0) + ) + + x = torch.rand(100, 1) + + result_single = torch.stack([ + splines[i](x) for i in range(N) + ]) + result_single = result_single.squeeze(2).permute(1, 0, 2) + out_vectorized = vectorized_spline(x) + print(out_vectorized.shape) + print(result_single.shape) + assert torch.allclose(out_vectorized, result_single, atol=1e-5, rtol=1e-5) \ No newline at end of file