From 78128b10623a9cc51f4f7946bc30f0b5a8523035 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Wed, 3 Dec 2025 17:57:18 +0100 Subject: [PATCH 01/10] implement autoregressive condition, time_weighting, solver --- pina/condition/__init__.py | 3 + pina/condition/autoregressive_condition.py | 91 ++++++++++++++++++ pina/loss/__init__.py | 10 ++ pina/loss/time_weighting.py | 57 ++++++++++++ pina/loss/time_weighting_interface.py | 24 +++++ pina/solver/__init__.py | 5 + pina/solver/autoregressive_solver/__init__.py | 4 + .../autoregressive_solver.py | 88 ++++++++++++++++++ .../autoregressive_solver_interface.py | 93 +++++++++++++++++++ 9 files changed, 375 insertions(+) create mode 100644 pina/condition/autoregressive_condition.py create mode 100644 pina/loss/time_weighting.py create mode 100644 pina/loss/time_weighting_interface.py create mode 100644 pina/solver/autoregressive_solver/__init__.py create mode 100644 pina/solver/autoregressive_solver/autoregressive_solver.py create mode 100644 pina/solver/autoregressive_solver/autoregressive_solver_interface.py diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..502c34ae9 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -15,6 +15,7 @@ "DataCondition", "GraphDataCondition", "TensorDataCondition", + "AutoregressiveCondition", ] from .condition_interface import ConditionInterface @@ -37,3 +38,5 @@ GraphDataCondition, TensorDataCondition, ) + +from .autoregressive_condition import AutoregressiveCondition diff --git a/pina/condition/autoregressive_condition.py b/pina/condition/autoregressive_condition.py new file mode 100644 index 000000000..1d48b638d --- /dev/null +++ b/pina/condition/autoregressive_condition.py @@ -0,0 +1,91 @@ +import torch +from .condition_interface import ConditionInterface +from ..loss import TimeWeightingInterface, ConstantTimeWeighting +from ..utils import check_consistency + + +class AutoregressiveCondition(ConditionInterface): + """ + A specialized condition for autoregressive tasks. + It generates input/unroll pairs from a single time-series tensor. + """ + + __slots__ = ["input", "unroll"] + + def __init__( + self, + data, + unroll_length, + num_unrolls=None, + randomize=True, + time_weighting=None, + ): + """ + Create an AutoregressiveCondition. + """ + super().__init__() + + self._n_timesteps, n_features = data.shape + self._unroll_length = unroll_length + self._requested_num_unrolls = num_unrolls + self._randomize = randomize + + # time weighting: weight the loss differently along the unroll + if time_weighting is None: + self._time_weighting = ConstantTimeWeighting() + else: + check_consistency(time_weighting, TimeWeightingInterface) + self._time_weighting = time_weighting + + # windows creation + initial_data = [] + unroll_data = [] + + for starting_index in self.starting_indices: + initial_data.append(data[starting_index]) + target_start = starting_index + 1 + unroll_data.append( + data[target_start : target_start + self._unroll_length, :] + ) + + self.input = torch.stack(initial_data) # [num_unrolls, features] + self.unroll = torch.stack( + unroll_data + ) # [num_unrolls, unroll_length, features] + + @property + def unroll_length(self): + return self._unroll_length + + @property + def time_weighting(self): + return self._time_weighting + + @property + def max_start_idx(self): + max_start_idx = self._n_timesteps - self._unroll_length + assert max_start_idx > 0, "Provided data sequence too short" + return max_start_idx + + @property + def num_unrolls(self): + if self._requested_num_unrolls is None: + return self.max_start_idx + else: + assert ( + self._requested_num_unrolls < self.max_start_idx + ), "too many samples requested" + return self._requested_num_unrolls + + @property + def starting_indices(self): + all_starting_indices = torch.arange(self.max_start_idx) + + if self._randomize: + perm = torch.randperm(len(all_starting_indices)) + return all_starting_indices[perm[: self.num_unrolls]] + else: + selected_indices = torch.linspace( + 0, len(all_starting_indices) - 1, self.num_unrolls + ).long() + return all_starting_indices[selected_indices] diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index d91cf7ab0..2d8ab288e 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -9,6 +9,10 @@ "NeuralTangentKernelWeighting", "SelfAdaptiveWeighting", "LinearWeighting", + "TimeWeightingInterface", + "ConstantTimeWeighting", + "ExponentialTimeWeighting", + "LinearTimeWeighting", ] from .loss_interface import LossInterface @@ -19,3 +23,9 @@ from .ntk_weighting import NeuralTangentKernelWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting from .linear_weighting import LinearWeighting +from .time_weighting_interface import TimeWeightingInterface +from .time_weighting import ( + ConstantTimeWeighting, + ExponentialTimeWeighting, + LinearTimeWeighting, +) diff --git a/pina/loss/time_weighting.py b/pina/loss/time_weighting.py new file mode 100644 index 000000000..0b1d1ed65 --- /dev/null +++ b/pina/loss/time_weighting.py @@ -0,0 +1,57 @@ +"""Module for the Time Weighting.""" + +import torch +from .time_weighting_interface import TimeWeightingInterface + + +class ConstantTimeWeighting(TimeWeightingInterface): + """ + Weighting scheme that assigns equal weight to all time steps. + """ + + def __call__(self, num_steps, device): + return torch.ones(num_steps, device=device) / num_steps + + +class ExponentialTimeWeighting(TimeWeightingInterface): + """ + Weighting scheme change exponentially with time. + gamma > 1.0: increasing weights + 0 < gamma < 1.0: decreasing weights + weight at time t is gamma^t + """ + + def __init__(self, gamma=0.9): + """ + Initialization of the :class:`ExponentialTimeWeighting` class. + :param float gamma: The decay factor. Default is 0.9. + """ + self.gamma = gamma + + def __call__(self, num_steps, device): + steps = torch.arange(num_steps, device=device, dtype=torch.float32) + weights = self.gamma**steps + return weights / weights.sum() + + +class LinearTimeWeighting(TimeWeightingInterface): + """ + Weighting scheme that changes linearly from a start weight to an end weight. + """ + + def __init__(self, start=0.1, end=1.0): + """ + Initialization of the :class:`LinearDecayTimeWeighting` class. + + :param float start: The starting weight. Default is 0.1. + :param float end: The ending weight. Default is 1.0. + """ + self.start = start + self.end = end + + def __call__(self, num_steps, device): + if num_steps == 1: + return torch.ones(1, device=device) + + weights = torch.linspace(self.start, self.end, num_steps, device=device) + return weights / weights.sum() diff --git a/pina/loss/time_weighting_interface.py b/pina/loss/time_weighting_interface.py new file mode 100644 index 000000000..9d9781351 --- /dev/null +++ b/pina/loss/time_weighting_interface.py @@ -0,0 +1,24 @@ +"""Module for the Time Weighting Interface.""" + +from abc import ABCMeta, abstractmethod +import torch + + +class TimeWeightingInterface(metaclass=ABCMeta): + """ + Abstract base class for all time weighting schemas. All time weighting + schemas should inherit from this class. + """ + + @abstractmethod + def __call__(self, num_steps, device): + """ + Compute the weights for the time steps. + + :param int num_steps: The number of time steps. + :param torch.device device: The device on which the weights should be + created. + :return: The weights for the time steps. + :rtype: torch.Tensor + """ + pass diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index 43f18078f..e7d48e2b3 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -18,6 +18,7 @@ "DeepEnsembleSupervisedSolver", "DeepEnsemblePINN", "GAROM", + "AutoregressiveSolver", ] from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface @@ -41,3 +42,7 @@ DeepEnsemblePINN, ) from .garom import GAROM +from .autoregressive_solver import ( + AutoregressiveSolver, + AutoregressiveSolverInterface, +) diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py new file mode 100644 index 000000000..9ef7c43e1 --- /dev/null +++ b/pina/solver/autoregressive_solver/__init__.py @@ -0,0 +1,4 @@ +__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"] + +from .autoregressive_solver import AutoregressiveSolver +from .autoregressive_solver_interface import AutoregressiveSolverInterface diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py new file mode 100644 index 000000000..d0a46c310 --- /dev/null +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -0,0 +1,88 @@ +import torch +from torch.nn.modules.loss import _Loss + +from pina.utils import check_consistency +from pina.solver.solver import SingleSolverInterface +from pina.condition import AutoregressiveCondition +from pina.loss import ( + LossInterface, + TimeWeightingInterface, + ConstantTimeWeighting, +) +from .autoregressive_solver_interface import AutoregressiveSolverInterface + + +class AutoregressiveSolver( + AutoregressiveSolverInterface, SingleSolverInterface +): + """ + Autoregressive Solver class. + """ + + accepted_conditions_types = AutoregressiveCondition + + def __init__( + self, + problem, + model, + loss=None, + optimizer=None, + scheduler=None, + weighting=None, + use_lt=False, + ): + """ + Initialization of the :class:`AutoregressiveSolver` class. + """ + super().__init__( + problem=problem, + model=model, + loss=loss, + optimizer=optimizer, + scheduler=scheduler, + weighting=weighting, + use_lt=use_lt, + ) + + def loss_data(self, input, target, unroll_length, time_weighting): + """ + Compute the data loss for the recursive autoregressive solver. + This will be applied to each condition individually. + """ + steps_to_predict = unroll_length - 1 + # weights are passed from the condition + weights = time_weighting(steps_to_predict, device=input.device) + + total_loss = 0.0 + current_state = input + + for step in range(steps_to_predict): + + predicted_next_state = self.forward( + current_state + ) # [batch_size, features] + actual_next_state = target[:, step, :] # [batch_size, features] + + step_loss = self.loss(predicted_next_state, actual_next_state) + + total_loss += step_loss * weights[step] + + current_state = predicted_next_state.detach() + + return total_loss + + def predict(self, initial_state, num_steps): + """ + Make recursive predictions starting from an initial state. + """ + self.eval() # Set model to evaluation mode + + current_state = initial_state + predictions = [current_state] # Store initial state without batch dim + with torch.no_grad(): + for step in range(num_steps): + next_state = self.forward(current_state) + predictions.append(next_state) # Keep batch dim for storage + current_state = next_state + + return torch.stack(predictions) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py new file mode 100644 index 000000000..e895705fe --- /dev/null +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -0,0 +1,93 @@ +"""Module for the Autoregressive solver interface.""" + +from abc import abstractmethod +import torch +from torch.nn.modules.loss import _Loss + +from ..solver import SolverInterface +from ...utils import check_consistency +from ...loss.loss_interface import LossInterface +from ...loss import TimeWeightingInterface, ConstantTimeWeighting +from ...condition import AutoregressiveCondition + + +class AutoregressiveSolverInterface(SolverInterface): + + accepted_conditions_types = AutoregressiveCondition + + def __init__(self, loss=None, **kwargs): + + if loss is None: + loss = torch.nn.MSELoss() + + super().__init__(**kwargs) + + check_consistency(loss, (LossInterface, _Loss), subclass=False) + self._loss_fn = loss + + def optimization_cycle(self, batch): + """ + Optimization cycle for this family of solvers. + Iterates over each conditions and each time applies the specialized loss_data function. + """ + + condition_loss = {} + for condition_name, points in batch: + condition = self.problem.conditions[condition_name] + + unroll_length = getattr(condition, "unroll_length", None) + time_weighting = getattr(condition, "time_weighting", None) + + if "unroll" in points: + loss = self.loss_data( + points["input"], + points["unroll"], + unroll_length, + time_weighting, + ) + condition_loss[condition_name] = loss + return condition_loss + + @abstractmethod + def loss_data(self, input, target, unroll_length, time_weighting): + """ + Computes the data loss for each condition. + N.B.: unroll_length and time_weighting are attributes of the condition. + + :param torch.Tensor input: Initial states. + :param torch.Tensor target: Target sequences. + :param int unroll_length: The number of steps to unroll (attribute of the condition). + :param TimeWeightingInterface time_weighting: The time weighting strategy (attribute of the condition). + :return: The average loss over all unroll steps. + :rtype: torch.Tensor + """ + pass + + @abstractmethod + def predict(self, initial_state, num_steps): + """ + Make recursive predictions starting from an initial state. + + :param torch.Tensor initial_state: Initial state tensor. + :param int num_steps: Number of steps to predict ahead. + :return: Tensor of predictions. + :rtype: torch.Tensor + """ + pass + + @property + def loss(self): + """ + The loss function to be minimized. + + :return: The loss function to be minimized. + :rtype: torch.nn.Module + """ + return self._loss_fn + + @property + def time_weighting(self): + """ + The time weighting strategy. + """ + return self._time_weighting From 76e592567c65590ad3cc6356660ddd8a6b9804f7 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Tue, 16 Dec 2025 18:12:41 +0100 Subject: [PATCH 02/10] implement everything into solver --- autoregressive_prova_generic_condition.py | 149 ++++++++++++++++++ pina/condition/__init__.py | 3 - pina/condition/autoregressive_condition.py | 91 ----------- pina/loss/__init__.py | 10 -- pina/loss/time_weighting.py | 57 ------- pina/loss/time_weighting_interface.py | 24 --- .../autoregressive_solver.py | 142 +++++++++++++---- .../autoregressive_solver_interface.py | 54 +++---- 8 files changed, 285 insertions(+), 245 deletions(-) create mode 100644 autoregressive_prova_generic_condition.py delete mode 100644 pina/condition/autoregressive_condition.py delete mode 100644 pina/loss/time_weighting.py delete mode 100644 pina/loss/time_weighting_interface.py diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py new file mode 100644 index 000000000..3c0796bbc --- /dev/null +++ b/autoregressive_prova_generic_condition.py @@ -0,0 +1,149 @@ +import torch +import matplotlib.pyplot as plt + +from pina import Trainer +from pina.optim import TorchOptimizer +from pina.problem import AbstractProblem +from pina.condition.data_condition import DataCondition +from pina.solver import AutoregressiveSolver + +NUM_TIMESTEPS = 100 +NUM_FEATURES = 15 +USE_TEST_MODEL = False + +# ============================================================================ +# DATA +# ============================================================================ + +torch.manual_seed(42) + +y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) +y[0] = torch.rand(NUM_FEATURES) # Random initial state + +for t in range(NUM_TIMESTEPS - 1): + y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum()) + +# ============================================================================ +# TRAINING +# ============================================================================ + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(y.shape[1], 20), + torch.nn.ReLU(), + torch.nn.Dropout(0.2), + torch.nn.Linear(20, y.shape[1]), + ) + + def forward(self, x): + return x + self.layers(x) + + +class TestModel(torch.nn.Module): + """ + Debug model that implements the EXACT transformation rule. + y[t+1] = 0.95 * y[t] + Expected loss is zero + """ + + def __init__(self, data_series=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True)) + return next_state + 0.0 * self.dummy_param + + +class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = { + "data_condition_0":DataCondition(input=y), + "data_condition_1":DataCondition(input=y), + } + +problem = Problem() + +#for each condition, define unroll instructions with these keys: +# - unroll_length: length of each unroll window +# - num_unrolls: number of unroll windows to create (if None, use all possible) +# - randomize: whether to randomize the starting indices of the unroll windows +unroll_instructions = { + "data_condition_0": { + "unroll_length": 10, + "num_unrolls": 89, + "randomize": True, + "eps": 5.0 + }, + "data_condition_1": { + "unroll_length": 20, + "num_unrolls": 79, + "randomize": True, + "eps": 10.0 + }, +} + +solver = AutoregressiveSolver( + unroll_instructions=unroll_instructions, + problem=problem, + model=TestModel() if USE_TEST_MODEL else SimpleModel(), + optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01), + eps=10.0, +) + +trainer = Trainer( + solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False +) +trainer.train() + +# ============================================================================ +# VISUALIZATION +# ============================================================================ + +test_start_idx = 50 +num_prediction_steps = 30 + +initial_state = y[test_start_idx] # Shape: [features] +predictions = solver.predict(initial_state, num_prediction_steps) +actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1] + +total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:]) +print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}") + +# viauzlize single dof +dof_to_plot = [0, 3, 6, 9, 12] +colors = [ + "r", + "g", + "b", + "c", + "m", + "y", + "k", +] +plt.figure(figsize=(10, 6)) +for dof, color in zip(dof_to_plot, colors): + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + actual[:, dof].numpy(), + label="Actual", + marker="o", + color=color, + markerfacecolor="none", + ) + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + predictions[:, dof].numpy(), + label="Predicted", + marker="x", + color=color, + ) + +plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}") +plt.legend() +plt.xlabel("Timestep") +plt.savefig(f"autoregressive_predictions.png") +plt.close() diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 502c34ae9..4e57811fb 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -15,7 +15,6 @@ "DataCondition", "GraphDataCondition", "TensorDataCondition", - "AutoregressiveCondition", ] from .condition_interface import ConditionInterface @@ -38,5 +37,3 @@ GraphDataCondition, TensorDataCondition, ) - -from .autoregressive_condition import AutoregressiveCondition diff --git a/pina/condition/autoregressive_condition.py b/pina/condition/autoregressive_condition.py deleted file mode 100644 index 1d48b638d..000000000 --- a/pina/condition/autoregressive_condition.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -from .condition_interface import ConditionInterface -from ..loss import TimeWeightingInterface, ConstantTimeWeighting -from ..utils import check_consistency - - -class AutoregressiveCondition(ConditionInterface): - """ - A specialized condition for autoregressive tasks. - It generates input/unroll pairs from a single time-series tensor. - """ - - __slots__ = ["input", "unroll"] - - def __init__( - self, - data, - unroll_length, - num_unrolls=None, - randomize=True, - time_weighting=None, - ): - """ - Create an AutoregressiveCondition. - """ - super().__init__() - - self._n_timesteps, n_features = data.shape - self._unroll_length = unroll_length - self._requested_num_unrolls = num_unrolls - self._randomize = randomize - - # time weighting: weight the loss differently along the unroll - if time_weighting is None: - self._time_weighting = ConstantTimeWeighting() - else: - check_consistency(time_weighting, TimeWeightingInterface) - self._time_weighting = time_weighting - - # windows creation - initial_data = [] - unroll_data = [] - - for starting_index in self.starting_indices: - initial_data.append(data[starting_index]) - target_start = starting_index + 1 - unroll_data.append( - data[target_start : target_start + self._unroll_length, :] - ) - - self.input = torch.stack(initial_data) # [num_unrolls, features] - self.unroll = torch.stack( - unroll_data - ) # [num_unrolls, unroll_length, features] - - @property - def unroll_length(self): - return self._unroll_length - - @property - def time_weighting(self): - return self._time_weighting - - @property - def max_start_idx(self): - max_start_idx = self._n_timesteps - self._unroll_length - assert max_start_idx > 0, "Provided data sequence too short" - return max_start_idx - - @property - def num_unrolls(self): - if self._requested_num_unrolls is None: - return self.max_start_idx - else: - assert ( - self._requested_num_unrolls < self.max_start_idx - ), "too many samples requested" - return self._requested_num_unrolls - - @property - def starting_indices(self): - all_starting_indices = torch.arange(self.max_start_idx) - - if self._randomize: - perm = torch.randperm(len(all_starting_indices)) - return all_starting_indices[perm[: self.num_unrolls]] - else: - selected_indices = torch.linspace( - 0, len(all_starting_indices) - 1, self.num_unrolls - ).long() - return all_starting_indices[selected_indices] diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 2d8ab288e..d91cf7ab0 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -9,10 +9,6 @@ "NeuralTangentKernelWeighting", "SelfAdaptiveWeighting", "LinearWeighting", - "TimeWeightingInterface", - "ConstantTimeWeighting", - "ExponentialTimeWeighting", - "LinearTimeWeighting", ] from .loss_interface import LossInterface @@ -23,9 +19,3 @@ from .ntk_weighting import NeuralTangentKernelWeighting from .self_adaptive_weighting import SelfAdaptiveWeighting from .linear_weighting import LinearWeighting -from .time_weighting_interface import TimeWeightingInterface -from .time_weighting import ( - ConstantTimeWeighting, - ExponentialTimeWeighting, - LinearTimeWeighting, -) diff --git a/pina/loss/time_weighting.py b/pina/loss/time_weighting.py deleted file mode 100644 index 0b1d1ed65..000000000 --- a/pina/loss/time_weighting.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Module for the Time Weighting.""" - -import torch -from .time_weighting_interface import TimeWeightingInterface - - -class ConstantTimeWeighting(TimeWeightingInterface): - """ - Weighting scheme that assigns equal weight to all time steps. - """ - - def __call__(self, num_steps, device): - return torch.ones(num_steps, device=device) / num_steps - - -class ExponentialTimeWeighting(TimeWeightingInterface): - """ - Weighting scheme change exponentially with time. - gamma > 1.0: increasing weights - 0 < gamma < 1.0: decreasing weights - weight at time t is gamma^t - """ - - def __init__(self, gamma=0.9): - """ - Initialization of the :class:`ExponentialTimeWeighting` class. - :param float gamma: The decay factor. Default is 0.9. - """ - self.gamma = gamma - - def __call__(self, num_steps, device): - steps = torch.arange(num_steps, device=device, dtype=torch.float32) - weights = self.gamma**steps - return weights / weights.sum() - - -class LinearTimeWeighting(TimeWeightingInterface): - """ - Weighting scheme that changes linearly from a start weight to an end weight. - """ - - def __init__(self, start=0.1, end=1.0): - """ - Initialization of the :class:`LinearDecayTimeWeighting` class. - - :param float start: The starting weight. Default is 0.1. - :param float end: The ending weight. Default is 1.0. - """ - self.start = start - self.end = end - - def __call__(self, num_steps, device): - if num_steps == 1: - return torch.ones(1, device=device) - - weights = torch.linspace(self.start, self.end, num_steps, device=device) - return weights / weights.sum() diff --git a/pina/loss/time_weighting_interface.py b/pina/loss/time_weighting_interface.py deleted file mode 100644 index 9d9781351..000000000 --- a/pina/loss/time_weighting_interface.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Module for the Time Weighting Interface.""" - -from abc import ABCMeta, abstractmethod -import torch - - -class TimeWeightingInterface(metaclass=ABCMeta): - """ - Abstract base class for all time weighting schemas. All time weighting - schemas should inherit from this class. - """ - - @abstractmethod - def __call__(self, num_steps, device): - """ - Compute the weights for the time steps. - - :param int num_steps: The number of time steps. - :param torch.device device: The device on which the weights should be - created. - :return: The weights for the time steps. - :rtype: torch.Tensor - """ - pass diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index d0a46c310..0606a3fd6 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -1,14 +1,7 @@ import torch -from torch.nn.modules.loss import _Loss - from pina.utils import check_consistency from pina.solver.solver import SingleSolverInterface -from pina.condition import AutoregressiveCondition -from pina.loss import ( - LossInterface, - TimeWeightingInterface, - ConstantTimeWeighting, -) +from pina.condition import DataCondition from .autoregressive_solver_interface import AutoregressiveSolverInterface @@ -19,12 +12,14 @@ class AutoregressiveSolver( Autoregressive Solver class. """ - accepted_conditions_types = AutoregressiveCondition + accepted_conditions_types = DataCondition def __init__( self, + unroll_instructions, problem, model, + eps=None, loss=None, optimizer=None, scheduler=None, @@ -33,8 +28,19 @@ def __init__( ): """ Initialization of the :class:`AutoregressiveSolver` class. + :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + this is supposed to map condition names to dict objects with unroll instructions. + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module model: The model to be trained. + :param torch.nn.Module or LossInterface or None loss: The loss function to be minimized. If None, defaults to MSELoss. + :param TorchOptimizer or None optimizer: The optimizer to be used. If None, no optimization is performed. + :param TorchScheduler or None scheduler: The learning rate scheduler to be used. If None, no scheduling is performed. + :param Weighting or None weighting: The weighting scheme for combining losses from different conditions. If None, equal weighting is applied. + :param bool use_lt: Whether to use learning rate tuning. """ + super().__init__( + unroll_instructions=unroll_instructions, problem=problem, model=model, loss=loss, @@ -44,45 +50,123 @@ def __init__( use_lt=use_lt, ) - def loss_data(self, input, target, unroll_length, time_weighting): + def loss_data(self, data, condition_unroll_instructions): """ Compute the data loss for the recursive autoregressive solver. This will be applied to each condition individually. + :param torch.Tensor data: all training data. + :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :return: Computed loss value. + :rtype: torch.Tensor """ - steps_to_predict = unroll_length - 1 - # weights are passed from the condition - weights = time_weighting(steps_to_predict, device=input.device) - total_loss = 0.0 - current_state = input + initial_data, unroll_data = self.create_unroll_windows( + data, condition_unroll_instructions + ) + + unroll_length = condition_unroll_instructions["unroll_length"] + current_state = initial_data # [num_unrolls, features] + + losses = [] + for step in range(unroll_length): + + predicted_state = self.forward(current_state) # [num_unrolls, features] + target_state = unroll_data[:, step, :] # [num_unrolls, features] + step_loss = self._loss_fn(predicted_state, target_state) + losses.append(step_loss) + current_state = predicted_state + + step_losses = torch.stack(losses) # [unroll_length] + + with torch.no_grad(): + weights = self.compute_adaptive_weights(step_losses.detach(), condition_unroll_instructions) + + weighted_loss = (step_losses * weights).sum() + return weighted_loss - for step in range(steps_to_predict): + def create_unroll_windows(self, data, condition_unroll_instructions): + """ + Create unroll windows for each condition from the data based on the provided instructions. + :param torch.Tensor data: The full data tensor. + :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :return: Tuple of initial data and unroll data tensors. + :rtype: (torch.Tensor, torch.Tensor) + """ - predicted_next_state = self.forward( - current_state - ) # [batch_size, features] - actual_next_state = target[:, step, :] # [batch_size, features] + unroll_length = condition_unroll_instructions["unroll_length"] + + start_list = [] + unroll_list = [] + for starting_index in self.decide_starting_indices( + data, condition_unroll_instructions + ): + idx = starting_index.item() + start = data[idx] + target_start = idx + 1 + unroll = data[target_start : target_start + unroll_length, :] + start_list.append(start) + unroll_list.append(unroll) + initial_data = torch.stack(start_list) # [num_unrolls, features] + unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features] + return initial_data, unroll_data + + def decide_starting_indices(self, data, condition_unroll_instructions): + """ + Decide the starting indices for unrolling based on the provided instructions. + :param torch.Tensor data: The full data tensor. + :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :return: Tensor of starting indices. + :rtype: torch.Tensor + """ + n_step, n_features = data.shape + num_unrolls = condition_unroll_instructions.get("num_unrolls", None) + unroll_length = condition_unroll_instructions["unroll_length"] + randomize = condition_unroll_instructions.get("randomize", True) - step_loss = self.loss(predicted_next_state, actual_next_state) + max_start = n_step - unroll_length + indices = torch.arange(max_start, device=data.device) - total_loss += step_loss * weights[step] + if num_unrolls is not None and num_unrolls < len(indices): + indices = indices[:num_unrolls] - current_state = predicted_next_state.detach() + if randomize: + indices = indices[torch.randperm(len(indices), device=data.device)] - return total_loss + return indices + + def compute_adaptive_weights(self, step_losses, condition_unroll_instructions): + """ + Compute adaptive weights for each time step based on cumulative losses. + :param torch.Tensor step_losses: Tensor of shape [unroll_length] containing losses at each time step. + :return: Tensor of shape [unroll_length] containing normalized weights. + :rtype: torch.Tensor + """ + num_steps = len(step_losses) + eps = condition_unroll_instructions.get("eps", None) + if eps is None: + weights = torch.ones_like(step_losses) + else: + weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) + + return weights / weights.sum() def predict(self, initial_state, num_steps): """ Make recursive predictions starting from an initial state. + :param torch.Tensor initial_state: Initial state tensor. + :param int num_steps: Number of steps to predict ahead. + :return: Tensor of predictions. + :rtype: torch.Tensor """ self.eval() # Set model to evaluation mode - + current_state = initial_state - predictions = [current_state] # Store initial state without batch dim + predictions = [current_state] + with torch.no_grad(): for step in range(num_steps): next_state = self.forward(current_state) - predictions.append(next_state) # Keep batch dim for storage + predictions.append(next_state) current_state = next_state - - return torch.stack(predictions) + + return torch.stack(predictions) \ No newline at end of file diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index e895705fe..d0a6f919a 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -7,58 +7,57 @@ from ..solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface -from ...loss import TimeWeightingInterface, ConstantTimeWeighting -from ...condition import AutoregressiveCondition +from ...condition import DataCondition class AutoregressiveSolverInterface(SolverInterface): - accepted_conditions_types = AutoregressiveCondition + def __init__(self, unroll_instructions, loss=None, **kwargs): + """ + Initialization of the :class:`AutoregressiveSolverInterface` class. + :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + this is supposed to map condition names to dict objects with unroll instructions. + :param loss: The loss function to be minimized. If None, defaults to MSELoss. + :type loss: torch.nn.Module or LossInterface, optional + """ - def __init__(self, loss=None, **kwargs): + super().__init__(**kwargs) if loss is None: loss = torch.nn.MSELoss() - super().__init__(**kwargs) - check_consistency(loss, (LossInterface, _Loss), subclass=False) self._loss_fn = loss + self._unroll_instructions = unroll_instructions def optimization_cycle(self, batch): """ Optimization cycle for this family of solvers. Iterates over each conditions and each time applies the specialized loss_data function. + :param dict batch: A dictionary mapping condition names to data batches. + :return: A dictionary mapping condition names to computed loss values. + :rtype: dict """ condition_loss = {} for condition_name, points in batch: - condition = self.problem.conditions[condition_name] - - unroll_length = getattr(condition, "unroll_length", None) - time_weighting = getattr(condition, "time_weighting", None) - - if "unroll" in points: - loss = self.loss_data( + condition_unroll_instructions = self._unroll_instructions[condition_name] + loss = self.loss_data( points["input"], - points["unroll"], - unroll_length, - time_weighting, + condition_unroll_instructions, ) condition_loss[condition_name] = loss return condition_loss @abstractmethod - def loss_data(self, input, target, unroll_length, time_weighting): + def loss_data(self, input, condition_unroll_instructions): """ Computes the data loss for each condition. - N.B.: unroll_length and time_weighting are attributes of the condition. + N.B.: This loss_data function must make use of unroll_instructions to know how to unroll the model. - :param torch.Tensor input: Initial states. - :param torch.Tensor target: Target sequences. - :param int unroll_length: The number of steps to unroll (attribute of the condition). - :param TimeWeightingInterface time_weighting: The time weighting strategy (attribute of the condition). - :return: The average loss over all unroll steps. + :param torch.Tensor input: all training data. + :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :return: Computed loss value. :rtype: torch.Tensor """ pass @@ -83,11 +82,4 @@ def loss(self): :return: The loss function to be minimized. :rtype: torch.nn.Module """ - return self._loss_fn - - @property - def time_weighting(self): - """ - The time weighting strategy. - """ - return self._time_weighting + return self._loss_fn \ No newline at end of file From bb2f925f7388dd89ef40a56a169196ded3699954 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Mon, 12 Jan 2026 18:21:42 +0100 Subject: [PATCH 03/10] add dataclass for managing unroll settings --- autoregressive_prova_generic_condition.py | 35 ++++----- pina/solver/__init__.py | 1 + pina/solver/autoregressive_solver/__init__.py | 1 + .../autoregressive_solver.py | 74 +++++++------------ .../autoregressive_solver_interface.py | 28 +++++-- 5 files changed, 69 insertions(+), 70 deletions(-) diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py index 3c0796bbc..4812048fb 100644 --- a/autoregressive_prova_generic_condition.py +++ b/autoregressive_prova_generic_condition.py @@ -5,7 +5,7 @@ from pina.optim import TorchOptimizer from pina.problem import AbstractProblem from pina.condition.data_condition import DataCondition -from pina.solver import AutoregressiveSolver +from pina.solver import AutoregressiveSolver,UnrollInstructions NUM_TIMESTEPS = 100 NUM_FEATURES = 15 @@ -71,27 +71,28 @@ class Problem(AbstractProblem): # - unroll_length: length of each unroll window # - num_unrolls: number of unroll windows to create (if None, use all possible) # - randomize: whether to randomize the starting indices of the unroll windows -unroll_instructions = { - "data_condition_0": { - "unroll_length": 10, - "num_unrolls": 89, - "randomize": True, - "eps": 5.0 - }, - "data_condition_1": { - "unroll_length": 20, - "num_unrolls": 79, - "randomize": True, - "eps": 10.0 - }, -} +unroll_instructions_list = [ + UnrollInstructions( + condition_name="data_condition_0", + unroll_length=10, + num_unrolls=89, + randomize=True, + eps=5.0 + ), + UnrollInstructions( + condition_name="data_condition_1", + unroll_length=20, + num_unrolls=79, + randomize=True, + eps=10.0 + ), +] solver = AutoregressiveSolver( - unroll_instructions=unroll_instructions, + unroll_instructions_list=unroll_instructions_list, problem=problem, model=TestModel() if USE_TEST_MODEL else SimpleModel(), optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01), - eps=10.0, ) trainer = Trainer( diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index e7d48e2b3..8494df8b0 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -45,4 +45,5 @@ from .autoregressive_solver import ( AutoregressiveSolver, AutoregressiveSolverInterface, + UnrollInstructions, ) diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py index 9ef7c43e1..ac0d60a12 100644 --- a/pina/solver/autoregressive_solver/__init__.py +++ b/pina/solver/autoregressive_solver/__init__.py @@ -2,3 +2,4 @@ from .autoregressive_solver import AutoregressiveSolver from .autoregressive_solver_interface import AutoregressiveSolverInterface +from .autoregressive_solver_interface import UnrollInstructions diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index 0606a3fd6..0a31d1ae2 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -3,7 +3,8 @@ from pina.solver.solver import SingleSolverInterface from pina.condition import DataCondition from .autoregressive_solver_interface import AutoregressiveSolverInterface - +from .autoregressive_solver_interface import UnrollInstructions +from typing import List class AutoregressiveSolver( AutoregressiveSolverInterface, SingleSolverInterface @@ -16,10 +17,9 @@ class AutoregressiveSolver( def __init__( self, - unroll_instructions, + unroll_instructions_list:List[UnrollInstructions], problem, model, - eps=None, loss=None, optimizer=None, scheduler=None, @@ -28,7 +28,7 @@ def __init__( ): """ Initialization of the :class:`AutoregressiveSolver` class. - :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + :param list unroll_instructions_list: A list of UnrollInstructions, one for each condition. this is supposed to map condition names to dict objects with unroll instructions. :param AbstractProblem problem: The problem to be solved. :param torch.nn.Module model: The model to be trained. @@ -40,7 +40,7 @@ def __init__( """ super().__init__( - unroll_instructions=unroll_instructions, + unroll_instructions_list=unroll_instructions_list, problem=problem, model=model, loss=loss, @@ -50,25 +50,23 @@ def __init__( use_lt=use_lt, ) - def loss_data(self, data, condition_unroll_instructions): + def loss_data(self, data, unroll_instructions:UnrollInstructions): """ Compute the data loss for the recursive autoregressive solver. This will be applied to each condition individually. :param torch.Tensor data: all training data. - :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :param UnrollInstructions unroll_instructions: instructions on how to unroll the model for this condition. :return: Computed loss value. :rtype: torch.Tensor """ initial_data, unroll_data = self.create_unroll_windows( - data, condition_unroll_instructions + data, unroll_instructions ) - - unroll_length = condition_unroll_instructions["unroll_length"] current_state = initial_data # [num_unrolls, features] losses = [] - for step in range(unroll_length): + for step in range(unroll_instructions.unroll_length): predicted_state = self.forward(current_state) # [num_unrolls, features] target_state = unroll_data[:, step, :] # [num_unrolls, features] @@ -79,76 +77,60 @@ def loss_data(self, data, condition_unroll_instructions): step_losses = torch.stack(losses) # [unroll_length] with torch.no_grad(): - weights = self.compute_adaptive_weights(step_losses.detach(), condition_unroll_instructions) + eps = unroll_instructions.eps + if eps is None: + weights = torch.ones_like(step_losses) + else: + weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) + weights = weights / weights.sum() - weighted_loss = (step_losses * weights).sum() - return weighted_loss + return (step_losses * weights).sum() - def create_unroll_windows(self, data, condition_unroll_instructions): + def create_unroll_windows(self, data, unroll_instructions:UnrollInstructions): """ Create unroll windows for each condition from the data based on the provided instructions. :param torch.Tensor data: The full data tensor. - :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. :return: Tuple of initial data and unroll data tensors. :rtype: (torch.Tensor, torch.Tensor) """ - unroll_length = condition_unroll_instructions["unroll_length"] + unroll_length = unroll_instructions.unroll_length start_list = [] unroll_list = [] for starting_index in self.decide_starting_indices( - data, condition_unroll_instructions + data, unroll_instructions ): idx = starting_index.item() - start = data[idx] - target_start = idx + 1 - unroll = data[target_start : target_start + unroll_length, :] - start_list.append(start) - unroll_list.append(unroll) + start_list.append(data[idx]) + unroll_list.append(data[idx+1 : idx+1+unroll_length, :]) + initial_data = torch.stack(start_list) # [num_unrolls, features] unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features] return initial_data, unroll_data - def decide_starting_indices(self, data, condition_unroll_instructions): + def decide_starting_indices(self, data, unroll_instructions:UnrollInstructions): """ Decide the starting indices for unrolling based on the provided instructions. :param torch.Tensor data: The full data tensor. - :param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition. + :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. :return: Tensor of starting indices. :rtype: torch.Tensor """ n_step, n_features = data.shape - num_unrolls = condition_unroll_instructions.get("num_unrolls", None) - unroll_length = condition_unroll_instructions["unroll_length"] - randomize = condition_unroll_instructions.get("randomize", True) + num_unrolls = unroll_instructions.num_unrolls - max_start = n_step - unroll_length + max_start = n_step - unroll_instructions.unroll_length indices = torch.arange(max_start, device=data.device) if num_unrolls is not None and num_unrolls < len(indices): indices = indices[:num_unrolls] - if randomize: + if unroll_instructions.randomize: indices = indices[torch.randperm(len(indices), device=data.device)] return indices - - def compute_adaptive_weights(self, step_losses, condition_unroll_instructions): - """ - Compute adaptive weights for each time step based on cumulative losses. - :param torch.Tensor step_losses: Tensor of shape [unroll_length] containing losses at each time step. - :return: Tensor of shape [unroll_length] containing normalized weights. - :rtype: torch.Tensor - """ - num_steps = len(step_losses) - eps = condition_unroll_instructions.get("eps", None) - if eps is None: - weights = torch.ones_like(step_losses) - else: - weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) - - return weights / weights.sum() def predict(self, initial_state, num_steps): """ diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index d0a6f919a..bf6a67462 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -3,19 +3,29 @@ from abc import abstractmethod import torch from torch.nn.modules.loss import _Loss +from dataclasses import dataclass from ..solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface from ...condition import DataCondition +from typing import Optional +@dataclass +class UnrollInstructions: + """Instructions for unrolling a single condition.""" + condition_name: str + unroll_length: int + num_unrolls: Optional[int] = None + randomize: bool = True + eps: Optional[float] = None class AutoregressiveSolverInterface(SolverInterface): - def __init__(self, unroll_instructions, loss=None, **kwargs): + def __init__(self, unroll_instructions_list, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - :param dict unroll_instructions: A dictionary specifying how to unroll each condition. + :param list unroll_instructions: A list of UnrollInstructions, one for each condition. this is supposed to map condition names to dict objects with unroll instructions. :param loss: The loss function to be minimized. If None, defaults to MSELoss. :type loss: torch.nn.Module or LossInterface, optional @@ -28,7 +38,7 @@ def __init__(self, unroll_instructions, loss=None, **kwargs): check_consistency(loss, (LossInterface, _Loss), subclass=False) self._loss_fn = loss - self._unroll_instructions = unroll_instructions + self._unroll_instructions_list = unroll_instructions_list def optimization_cycle(self, batch): """ @@ -41,22 +51,26 @@ def optimization_cycle(self, batch): condition_loss = {} for condition_name, points in batch: - condition_unroll_instructions = self._unroll_instructions[condition_name] + #find unroll instructions for this condition + unroll_instructions = next( + ui for ui in self._unroll_instructions_list + if ui.condition_name == condition_name + ) loss = self.loss_data( points["input"], - condition_unroll_instructions, + unroll_instructions, ) condition_loss[condition_name] = loss return condition_loss @abstractmethod - def loss_data(self, input, condition_unroll_instructions): + def loss_data(self, input, unroll_instructions:UnrollInstructions): """ Computes the data loss for each condition. N.B.: This loss_data function must make use of unroll_instructions to know how to unroll the model. :param torch.Tensor input: all training data. - :param dict condition_unroll_instructions: instructions on how to unroll the model for this condition. + :param UnrollInstruction unroll_instructions: instructions on how to unroll the model for this condition. :return: Computed loss value. :rtype: torch.Tensor """ From 1a3bfa8aea816de2f64e16b9fa0080124c9720a8 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Tue, 13 Jan 2026 14:52:06 +0100 Subject: [PATCH 04/10] add docstings and tests --- autoregressive_prova_generic_condition.py | 150 ------------ .../autoregressive_solver.py | 157 +++++++++---- .../autoregressive_solver_interface.py | 109 ++++++--- .../test_solver/test_autoregressive_solver.py | 213 ++++++++++++++++++ 4 files changed, 410 insertions(+), 219 deletions(-) delete mode 100644 autoregressive_prova_generic_condition.py create mode 100644 tests/test_solver/test_autoregressive_solver.py diff --git a/autoregressive_prova_generic_condition.py b/autoregressive_prova_generic_condition.py deleted file mode 100644 index 4812048fb..000000000 --- a/autoregressive_prova_generic_condition.py +++ /dev/null @@ -1,150 +0,0 @@ -import torch -import matplotlib.pyplot as plt - -from pina import Trainer -from pina.optim import TorchOptimizer -from pina.problem import AbstractProblem -from pina.condition.data_condition import DataCondition -from pina.solver import AutoregressiveSolver,UnrollInstructions - -NUM_TIMESTEPS = 100 -NUM_FEATURES = 15 -USE_TEST_MODEL = False - -# ============================================================================ -# DATA -# ============================================================================ - -torch.manual_seed(42) - -y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) -y[0] = torch.rand(NUM_FEATURES) # Random initial state - -for t in range(NUM_TIMESTEPS - 1): - y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum()) - -# ============================================================================ -# TRAINING -# ============================================================================ - -class SimpleModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.layers = torch.nn.Sequential( - torch.nn.Linear(y.shape[1], 20), - torch.nn.ReLU(), - torch.nn.Dropout(0.2), - torch.nn.Linear(20, y.shape[1]), - ) - - def forward(self, x): - return x + self.layers(x) - - -class TestModel(torch.nn.Module): - """ - Debug model that implements the EXACT transformation rule. - y[t+1] = 0.95 * y[t] - Expected loss is zero - """ - - def __init__(self, data_series=None): - super().__init__() - self.dummy_param = torch.nn.Parameter(torch.zeros(1)) - - def forward(self, x): - next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True)) - return next_state + 0.0 * self.dummy_param - - -class Problem(AbstractProblem): - output_variables = None - input_variables = None - conditions = { - "data_condition_0":DataCondition(input=y), - "data_condition_1":DataCondition(input=y), - } - -problem = Problem() - -#for each condition, define unroll instructions with these keys: -# - unroll_length: length of each unroll window -# - num_unrolls: number of unroll windows to create (if None, use all possible) -# - randomize: whether to randomize the starting indices of the unroll windows -unroll_instructions_list = [ - UnrollInstructions( - condition_name="data_condition_0", - unroll_length=10, - num_unrolls=89, - randomize=True, - eps=5.0 - ), - UnrollInstructions( - condition_name="data_condition_1", - unroll_length=20, - num_unrolls=79, - randomize=True, - eps=10.0 - ), -] - -solver = AutoregressiveSolver( - unroll_instructions_list=unroll_instructions_list, - problem=problem, - model=TestModel() if USE_TEST_MODEL else SimpleModel(), - optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01), -) - -trainer = Trainer( - solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False -) -trainer.train() - -# ============================================================================ -# VISUALIZATION -# ============================================================================ - -test_start_idx = 50 -num_prediction_steps = 30 - -initial_state = y[test_start_idx] # Shape: [features] -predictions = solver.predict(initial_state, num_prediction_steps) -actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1] - -total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:]) -print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}") - -# viauzlize single dof -dof_to_plot = [0, 3, 6, 9, 12] -colors = [ - "r", - "g", - "b", - "c", - "m", - "y", - "k", -] -plt.figure(figsize=(10, 6)) -for dof, color in zip(dof_to_plot, colors): - plt.plot( - range(test_start_idx, test_start_idx + num_prediction_steps + 1), - actual[:, dof].numpy(), - label="Actual", - marker="o", - color=color, - markerfacecolor="none", - ) - plt.plot( - range(test_start_idx, test_start_idx + num_prediction_steps + 1), - predictions[:, dof].numpy(), - label="Predicted", - marker="x", - color=color, - ) - -plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}") -plt.legend() -plt.xlabel("Timestep") -plt.savefig(f"autoregressive_predictions.png") -plt.close() diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index 0a31d1ae2..c754aae4d 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -6,18 +6,39 @@ from .autoregressive_solver_interface import UnrollInstructions from typing import List + class AutoregressiveSolver( AutoregressiveSolverInterface, SingleSolverInterface ): - """ - Autoregressive Solver class. + r""" + Autoregressive Solver for learning dynamical systems. + + This solver learns a one-step transition function + :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps + a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`. + + During training, the model is unrolled over multiple time steps to + learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`, + the model generates predictions recursively: + + .. math:: + \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t), + \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0 + + The loss is computed over the entire unroll window: + + .. math:: + \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2 + + where :math:`w_t` are exponential weights (if ``eps`` is specified) + that down-weight later predictions to stabilize training. """ accepted_conditions_types = DataCondition def __init__( self, - unroll_instructions_list:List[UnrollInstructions], + unroll_instructions_list: List[UnrollInstructions], problem, model, loss=None, @@ -28,15 +49,27 @@ def __init__( ): """ Initialization of the :class:`AutoregressiveSolver` class. - :param list unroll_instructions_list: A list of UnrollInstructions, one for each condition. - this is supposed to map condition names to dict objects with unroll instructions. - :param AbstractProblem problem: The problem to be solved. - :param torch.nn.Module model: The model to be trained. - :param torch.nn.Module or LossInterface or None loss: The loss function to be minimized. If None, defaults to MSELoss. - :param TorchOptimizer or None optimizer: The optimizer to be used. If None, no optimization is performed. - :param TorchScheduler or None scheduler: The learning rate scheduler to be used. If None, no scheduling is performed. - :param Weighting or None weighting: The weighting scheme for combining losses from different conditions. If None, equal weighting is applied. - :param bool use_lt: Whether to use learning rate tuning. + + :param list[UnrollInstructions] unroll_instructions_list: List of + :class:`UnrollInstructions` specifying how to create training + windows for each condition. + :param AbstractProblem problem: The problem instance containing + the time series data conditions. + :param torch.nn.Module model: Neural network that predicts the + next state given the current state. + :param torch.nn.Module loss: Loss function to minimize. + If ``None``, :class:`torch.nn.MSELoss` is used. + Default is ``None``. + :param TorchOptimizer optimizer: Optimizer for training. + If ``None``, :class:`torch.optim.Adam` is used. + Default is ``None``. + :param TorchScheduler scheduler: Learning rate scheduler. + If ``None``, no scheduling is applied. Default is ``None``. + :param WeightingInterface weighting: Weighting scheme for + combining losses from multiple conditions. + If ``None``, uniform weighting is used. Default is ``None``. + :param bool use_lt: Whether to use LabelTensors. + Default is ``False``. """ super().__init__( @@ -50,30 +83,37 @@ def __init__( use_lt=use_lt, ) - def loss_data(self, data, unroll_instructions:UnrollInstructions): + def loss_data(self, data, unroll_instructions: UnrollInstructions): """ Compute the data loss for the recursive autoregressive solver. - This will be applied to each condition individually. - :param torch.Tensor data: all training data. - :param UnrollInstructions unroll_instructions: instructions on how to unroll the model for this condition. - :return: Computed loss value. + + Creates unroll windows from the data, then iteratively predicts + each next state and computes the loss against the ground truth. + + :param torch.Tensor data: Time series with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + for window creation and loss weighting. + :return: Weighted sum of step losses. :rtype: torch.Tensor """ initial_data, unroll_data = self.create_unroll_windows( data, unroll_instructions ) - current_state = initial_data # [num_unrolls, features] + current_state = initial_data # [num_unrolls, features] losses = [] for step in range(unroll_instructions.unroll_length): - predicted_state = self.forward(current_state) # [num_unrolls, features] + predicted_state = self.forward( + current_state + ) # [num_unrolls, features] target_state = unroll_data[:, step, :] # [num_unrolls, features] step_loss = self._loss_fn(predicted_state, target_state) losses.append(step_loss) current_state = predicted_state - + step_losses = torch.stack(losses) # [unroll_length] with torch.no_grad(): @@ -83,16 +123,28 @@ def loss_data(self, data, unroll_instructions:UnrollInstructions): else: weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) weights = weights / weights.sum() - + return (step_losses * weights).sum() - def create_unroll_windows(self, data, unroll_instructions:UnrollInstructions): + def create_unroll_windows( + self, data, unroll_instructions: UnrollInstructions + ): """ - Create unroll windows for each condition from the data based on the provided instructions. - :param torch.Tensor data: The full data tensor. - :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. - :return: Tuple of initial data and unroll data tensors. - :rtype: (torch.Tensor, torch.Tensor) + Create unroll windows from time series data. + + Slices the input time series into overlapping windows, each + consisting of an initial state and subsequent target states. + + :param torch.Tensor data: Time series with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + specifying window length and count. + :return: Tuple of ``(initial_data, unroll_data)`` where: + + - ``initial_data``: Shape ``[num_unrolls, n_features]`` + - ``unroll_data``: Shape ``[num_unrolls, unroll_length, n_features]`` + + :rtype: tuple[torch.Tensor, torch.Tensor] """ unroll_length = unroll_instructions.unroll_length @@ -104,22 +156,34 @@ def create_unroll_windows(self, data, unroll_instructions:UnrollInstructions): ): idx = starting_index.item() start_list.append(data[idx]) - unroll_list.append(data[idx+1 : idx+1+unroll_length, :]) + unroll_list.append(data[idx + 1 : idx + 1 + unroll_length, :]) - initial_data = torch.stack(start_list) # [num_unrolls, features] - unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features] + initial_data = torch.stack(start_list) # [num_unrolls, features] + unroll_data = torch.stack( + unroll_list + ) # [num_unrolls, unroll_length, features] return initial_data, unroll_data - def decide_starting_indices(self, data, unroll_instructions:UnrollInstructions): + def decide_starting_indices( + self, data, unroll_instructions: UnrollInstructions + ): """ - Decide the starting indices for unrolling based on the provided instructions. - :param torch.Tensor data: The full data tensor. - :param UnrollInstructions unroll_instructions: Instructions on how to unroll the model for this condition. - :return: Tensor of starting indices. + Determine starting indices for unroll windows. + + Computes valid starting positions ensuring each window has + enough subsequent time steps for the specified unroll length. + + :param torch.Tensor data: Time series with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + with ``unroll_length``, ``num_unrolls``, and ``randomize``. + :return: 1D tensor of starting indices. :rtype: torch.Tensor """ n_step, n_features = data.shape num_unrolls = unroll_instructions.num_unrolls + if unroll_instructions.unroll_length >= n_step: + return [] #TODO: decide if it is better to raise an error here max_start = n_step - unroll_instructions.unroll_length indices = torch.arange(max_start, device=data.device) @@ -134,21 +198,28 @@ def decide_starting_indices(self, data, unroll_instructions:UnrollInstructions): def predict(self, initial_state, num_steps): """ - Make recursive predictions starting from an initial state. - :param torch.Tensor initial_state: Initial state tensor. - :param int num_steps: Number of steps to predict ahead. - :return: Tensor of predictions. + Generate predictions by recursively applying the model. + + Starting from the initial state, applies the model repeatedly + to generate a trajectory of predicted states. + + :param torch.Tensor initial_state: Starting state with shape + ``[n_features]``. + :param int num_steps: Number of future time steps to predict. + :return: Predicted trajectory with shape + ``[num_steps + 1, n_features]``, where the first row is + the initial state. :rtype: torch.Tensor """ self.eval() # Set model to evaluation mode - + current_state = initial_state predictions = [current_state] - + with torch.no_grad(): for step in range(num_steps): next_state = self.forward(current_state) predictions.append(next_state) current_state = next_state - - return torch.stack(predictions) \ No newline at end of file + + return torch.stack(predictions) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index bf6a67462..9d08b6b4e 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -13,22 +13,67 @@ @dataclass class UnrollInstructions: - """Instructions for unrolling a single condition.""" + """ + Configuration for creating unroll windows from a time series condition. + + This dataclass specifies how to slice a time series into overlapping + windows for autoregressive training. Each window consists of an initial + state and a sequence of subsequent states used as targets. + + :param str condition_name: Name of the condition in the problem's + conditions dictionary. Must match a key in + ``problem.conditions``. + :param int unroll_length: The length of each unroll window. + :param Optional[int] num_unrolls: The number of unroll windows to create. + If ``None``, all possible windows are used. Default is None. + :param bool randomize: Whether to randomize the starting indices of the unroll windows. + Default is True. + :param Optional[float] eps: Epsilon parameter for exponential loss weighting. + If ``None``, uniform weighting is applied. Default is ``None``. + + :Example: + >>> instructions = UnrollInstructions( + ... condition_name="trajectory", + ... unroll_length=10, + ... num_unrolls=100, + ... randomize=True, + ... eps=0.1 + ... ) + """ + condition_name: str unroll_length: int num_unrolls: Optional[int] = None randomize: bool = True eps: Optional[float] = None + class AutoregressiveSolverInterface(SolverInterface): + """ + Base class for autoregressive solvers. + + This interface defines solvers that learn to predict the next state + of a dynamical system given the current state. The solver uses an + unrolling strategy where predictions are made recursively over + multiple time steps during training. + The ``AutoregressiveSolverInterface`` is compatible with problems + containing :class:`~pina.condition.data_condition.DataCondition` + conditions, where the input represents a time series trajectory. + """ def __init__(self, unroll_instructions_list, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - :param list unroll_instructions: A list of UnrollInstructions, one for each condition. - this is supposed to map condition names to dict objects with unroll instructions. - :param loss: The loss function to be minimized. If None, defaults to MSELoss. - :type loss: torch.nn.Module or LossInterface, optional + + :param list[UnrollInstructions] unroll_instructions_list: List of + :class:`UnrollInstructions` objects, one for each condition + in the problem. Each instruction specifies how to create + unroll windows for training. + ::param torch.nn.Module loss: Loss function to minimize. + If ``None``, :class:`torch.nn.MSELoss` is used. + Default is ``None``. + :param kwargs: Additional keyword arguments passed to + :class:`~pina.solver.solver.SolverInterface`. """ super().__init__(**kwargs) @@ -43,35 +88,42 @@ def __init__(self, unroll_instructions_list, loss=None, **kwargs): def optimization_cycle(self, batch): """ Optimization cycle for this family of solvers. - Iterates over each conditions and each time applies the specialized loss_data function. - :param dict batch: A dictionary mapping condition names to data batches. - :return: A dictionary mapping condition names to computed loss values. - :rtype: dict + Iterates over each condition and each time applies the specialized loss_data function. + + :param list[tuple[str, dict]] batch: List of tuples where each + tuple contains a condition name and a dictionary with the + ``"input"`` key mapping to the time series tensor. + :return: Dictionary mapping condition names to computed loss values. + :rtype: dict[str, torch.Tensor] """ condition_loss = {} for condition_name, points in batch: - #find unroll instructions for this condition + # find unroll instructions for this condition unroll_instructions = next( - ui for ui in self._unroll_instructions_list + ui + for ui in self._unroll_instructions_list if ui.condition_name == condition_name ) loss = self.loss_data( - points["input"], - unroll_instructions, - ) + points["input"], + unroll_instructions, + ) condition_loss[condition_name] = loss return condition_loss @abstractmethod - def loss_data(self, input, unroll_instructions:UnrollInstructions): + def loss_data(self, input, unroll_instructions: UnrollInstructions): """ - Computes the data loss for each condition. - N.B.: This loss_data function must make use of unroll_instructions to know how to unroll the model. - - :param torch.Tensor input: all training data. - :param UnrollInstruction unroll_instructions: instructions on how to unroll the model for this condition. - :return: Computed loss value. + Compute the data loss for each condition. + This method must be implemented by subclasses to define the + specific loss computation strategy. + + :param torch.Tensor input: Time series data with shape + ``[n_timesteps, n_features]``. + :param UnrollInstructions unroll_instructions: Configuration + for creating unroll windows from the input data. + :return: Scalar loss value for this condition. :rtype: torch.Tensor """ pass @@ -79,11 +131,16 @@ def loss_data(self, input, unroll_instructions:UnrollInstructions): @abstractmethod def predict(self, initial_state, num_steps): """ - Make recursive predictions starting from an initial state. + Generate recursive predictions from an initial state. + + Starting from the initial state, repeatedly applies the model + to predict subsequent states. - :param torch.Tensor initial_state: Initial state tensor. - :param int num_steps: Number of steps to predict ahead. - :return: Tensor of predictions. + :param torch.Tensor initial_state: Starting state with shape + ``[n_features]`` or ``[batch_size, n_features]``. + :param int num_steps: Number of future steps to predict. + :return: Tensor of predictions with shape + ``[num_steps + 1, n_features]``, including the initial state. :rtype: torch.Tensor """ pass @@ -96,4 +153,4 @@ def loss(self): :return: The loss function to be minimized. :rtype: torch.nn.Module """ - return self._loss_fn \ No newline at end of file + return self._loss_fn diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py new file mode 100644 index 000000000..467f90afe --- /dev/null +++ b/tests/test_solver/test_autoregressive_solver.py @@ -0,0 +1,213 @@ +import pytest +import torch + +from pina import Trainer +from pina.optim import TorchOptimizer +from pina.problem import AbstractProblem +from pina.condition.data_condition import DataCondition +from pina.solver import AutoregressiveSolver, UnrollInstructions + +NUM_TIMESTEPS = 10 +NUM_FEATURES = 3 + + +@pytest.fixture +def y_data(): + torch.manual_seed(42) + y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) + y[0] = torch.rand(NUM_FEATURES) + for t in range(NUM_TIMESTEPS - 1): + y[t + 1] = 0.95 * y[t] + return y + + +# crate a test Model +class ExactModel(torch.nn.Module): + """ + This model implements the EXACT transformation rule. + y[t+1] = 0.95 * y[t] + Expected loss is zero + """ + + def __init__(self, data_series=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + next_state = 0.95 * x + return next_state + 0.0 * self.dummy_param + + +@pytest.fixture +def solver(y_data): + """Create a minimal solver for testing internal methods.""" + + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {"data": DataCondition(input=y_data)} + + return AutoregressiveSolver( + unroll_instructions_list=[ + UnrollInstructions(condition_name="data", unroll_length=3) + ], + problem=Problem(), + model=ExactModel(), + ) + + +# Tests start here ============================================== + + +def test_exact_model(y_data): + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = { + "data_condition": DataCondition(input=y_data), + } + + unroll_instruction = UnrollInstructions( + condition_name="data_condition", + unroll_length=5, + ) + + solver = AutoregressiveSolver( + unroll_instructions_list=[unroll_instruction], + problem=Problem(), + model=ExactModel(), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01), + ) + + loss = solver.loss_data(y_data, unroll_instruction) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + + trainer = Trainer( + solver=solver, + max_epochs=5, + accelerator="cpu", + shuffle=False, + enable_model_summary=False, + ) + trainer.train() + + loss_after_training = solver.loss_data(y_data, unroll_instruction) + assert torch.isclose(loss_after_training, torch.tensor(0.0), atol=1e-6) + + predictions = solver.predict( + initial_state=y_data[0], num_steps=NUM_TIMESTEPS - 1 + ) + expected_predictions = y_data + assert torch.allclose(predictions, expected_predictions, atol=1e-6) + + +def test_indices_sequential_when_no_randomize(solver, y_data): + """Indices should be [0, 1, 2, ...] when randomize=False.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + # y_data has 10 timesteps, unroll_length=3 → max_start = 10 - 3 = 7 + expected = torch.arange(7) + assert torch.equal(indices, expected) + + +def test_indices_permuted_when_randomize(solver, y_data): + """Indices should contain same values but permuted when randomize=True.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + randomize=True, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + expected_values = set(range(7)) + actual_values = set(indices.tolist()) + assert actual_values == expected_values + + +def test_num_unrolls_parameter(solver, y_data): + """num_unrolls should limit the number of indices returned.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + num_unrolls=3, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + assert len(indices) == 3 + assert torch.equal(indices, torch.arange(3)) + + +def test_num_unrolls_greater_than_max_possible(solver, y_data): + """num_unrolls > max_possible should return all possible indices.""" + unroll_length = 3 + maximum_number_of_unrolls = ( + NUM_TIMESTEPS - unroll_length + ) # 10 - unroll_length(3) = 7 + instructions = UnrollInstructions( + condition_name="data", + unroll_length=unroll_length, + num_unrolls=100, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + + assert len(indices) == maximum_number_of_unrolls + + +def test_no_valid_indices_when_unroll_too_long(solver, y_data): + """When unroll_length >= n_timesteps, no valid indices exist.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=NUM_TIMESTEPS + 1, + randomize=False, + ) + indices = solver.decide_starting_indices(y_data, instructions) + print(indices) + assert len(indices) == 0 + + +def test_unroll_window_shape(solver, y_data): + """Unroll windows should have correct shapes.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=4, + num_unrolls=2, + randomize=False, + ) + initial_data, unroll_data = solver.create_unroll_windows( + y_data, instructions + ) + + assert initial_data.shape == (2, NUM_FEATURES) # [num_unrolls, features] + assert unroll_data.shape == ( + 2, + 4, + NUM_FEATURES, + ) # [num_unrolls, unroll_length, features] + + +def test_unroll_windows_content(solver, y_data): + """Verify actual content of unroll windows.""" + instructions = UnrollInstructions( + condition_name="data", + unroll_length=3, + num_unrolls=2, + randomize=False, + ) + initial_data, unroll_data = solver.create_unroll_windows( + y_data, instructions + ) + + # initial_data[i] should be y_data[i] + assert torch.equal(initial_data[0], y_data[0]) + assert torch.equal(initial_data[1], y_data[1]) + + # unroll_data[i] should be y_data[i+1 : i+1+unroll_length] + assert torch.equal(unroll_data[0], y_data[1:4]) + assert torch.equal(unroll_data[1], y_data[2:5]) From aa644c5b658782aacc9198c8b88baf4306b2cd3c Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Tue, 13 Jan 2026 15:10:45 +0100 Subject: [PATCH 05/10] fix formatting --- .../autoregressive_solver.py | 28 +++++++++---------- .../autoregressive_solver_interface.py | 7 +++-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index c754aae4d..d91e1e254 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -16,7 +16,7 @@ class AutoregressiveSolver( This solver learns a one-step transition function :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`. - + During training, the model is unrolled over multiple time steps to learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`, the model generates predictions recursively: @@ -183,7 +183,7 @@ def decide_starting_indices( n_step, n_features = data.shape num_unrolls = unroll_instructions.num_unrolls if unroll_instructions.unroll_length >= n_step: - return [] #TODO: decide if it is better to raise an error here + return [] # TODO: decide if it is better to raise an error here max_start = n_step - unroll_instructions.unroll_length indices = torch.arange(max_start, device=data.device) @@ -198,18 +198,18 @@ def decide_starting_indices( def predict(self, initial_state, num_steps): """ - Generate predictions by recursively applying the model. - - Starting from the initial state, applies the model repeatedly - to generate a trajectory of predicted states. - - :param torch.Tensor initial_state: Starting state with shape - ``[n_features]``. - :param int num_steps: Number of future time steps to predict. - :return: Predicted trajectory with shape - ``[num_steps + 1, n_features]``, where the first row is - the initial state. - :rtype: torch.Tensor + Generate predictions by recursively applying the model. + + Starting from the initial state, applies the model repeatedly + to generate a trajectory of predicted states. + + :param torch.Tensor initial_state: Starting state with shape + ``[n_features]``. + :param int num_steps: Number of future time steps to predict. + :return: Predicted trajectory with shape + ``[num_steps + 1, n_features]``, where the first row is + the initial state. + :rtype: torch.Tensor """ self.eval() # Set model to evaluation mode diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index 9d08b6b4e..67f64b6cd 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -11,6 +11,7 @@ from ...condition import DataCondition from typing import Optional + @dataclass class UnrollInstructions: """ @@ -30,7 +31,7 @@ class UnrollInstructions: Default is True. :param Optional[float] eps: Epsilon parameter for exponential loss weighting. If ``None``, uniform weighting is applied. Default is ``None``. - + :Example: >>> instructions = UnrollInstructions( ... condition_name="trajectory", @@ -64,7 +65,7 @@ class AutoregressiveSolverInterface(SolverInterface): def __init__(self, unroll_instructions_list, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - + :param list[UnrollInstructions] unroll_instructions_list: List of :class:`UnrollInstructions` objects, one for each condition in the problem. Each instruction specifies how to create @@ -89,7 +90,7 @@ def optimization_cycle(self, batch): """ Optimization cycle for this family of solvers. Iterates over each condition and each time applies the specialized loss_data function. - + :param list[tuple[str, dict]] batch: List of tuples where each tuple contains a condition name and a dictionary with the ``"input"`` key mapping to the time series tensor. From c7e3011ad3d6de8c5b5fc8c60a168a4f9e90e862 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Thu, 15 Jan 2026 17:39:36 +0100 Subject: [PATCH 06/10] separate function for unrolling --- pina/solver/__init__.py | 1 - pina/solver/autoregressive_solver/__init__.py | 1 - .../autoregressive_solver.py | 244 ++++++++++-------- .../autoregressive_solver_interface.py | 107 ++------ .../test_solver/test_autoregressive_solver.py | 217 +++++----------- 5 files changed, 232 insertions(+), 338 deletions(-) diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index 8494df8b0..e7d48e2b3 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -45,5 +45,4 @@ from .autoregressive_solver import ( AutoregressiveSolver, AutoregressiveSolverInterface, - UnrollInstructions, ) diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py index ac0d60a12..9ef7c43e1 100644 --- a/pina/solver/autoregressive_solver/__init__.py +++ b/pina/solver/autoregressive_solver/__init__.py @@ -2,4 +2,3 @@ from .autoregressive_solver import AutoregressiveSolver from .autoregressive_solver_interface import AutoregressiveSolverInterface -from .autoregressive_solver_interface import UnrollInstructions diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index d91e1e254..eb38b1b0c 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -3,10 +3,8 @@ from pina.solver.solver import SingleSolverInterface from pina.condition import DataCondition from .autoregressive_solver_interface import AutoregressiveSolverInterface -from .autoregressive_solver_interface import UnrollInstructions from typing import List - class AutoregressiveSolver( AutoregressiveSolverInterface, SingleSolverInterface ): @@ -38,9 +36,9 @@ class AutoregressiveSolver( def __init__( self, - unroll_instructions_list: List[UnrollInstructions], problem, model, + conditions_settings={}, loss=None, optimizer=None, scheduler=None, @@ -50,9 +48,6 @@ def __init__( """ Initialization of the :class:`AutoregressiveSolver` class. - :param list[UnrollInstructions] unroll_instructions_list: List of - :class:`UnrollInstructions` specifying how to create training - windows for each condition. :param AbstractProblem problem: The problem instance containing the time series data conditions. :param torch.nn.Module model: Neural network that predicts the @@ -73,153 +68,186 @@ def __init__( """ super().__init__( - unroll_instructions_list=unroll_instructions_list, problem=problem, model=model, + conditions_settings=conditions_settings, loss=loss, optimizer=optimizer, scheduler=scheduler, weighting=weighting, use_lt=use_lt, ) - - def loss_data(self, data, unroll_instructions: UnrollInstructions): + + @staticmethod + def unroll( + data, unroll_length: int, num_unrolls=None, randomize: bool = True + ): """ - Compute the data loss for the recursive autoregressive solver. - - Creates unroll windows from the data, then iteratively predicts - each next state and computes the loss against the ground truth. + Create unroll windows from time series data. - :param torch.Tensor data: Time series with shape - ``[n_timesteps, n_features]``. - :param UnrollInstructions unroll_instructions: Configuration - for window creation and loss weighting. - :return: Weighted sum of step losses. + This is a pre-processing step. It slices the input time series into + overlapping windows of length ``Twin = unroll_length + 1`` along the + time axis (axis 0). Each window contains the initial state and the + subsequent target states used to compute a multi-step loss. + + :param torch.Tensor data: Time series tensor with shape ``[T, *state_shape]``. + The first axis is interpreted as time. + :param int unroll_length: Number of transitions in each window. + Each window has length ``unroll_length + 1``. + :param int num_unrolls: Maximum number of windows to return. If ``None``, + all valid windows are returned. Default is ``None``. + :param bool randomize: If ``True``, starting indices are randomly permuted + before applying ``num_unrolls``. Default is ``True``. + :return: Tensor of unroll windows with shape ``[Nw, unroll_length + 1, *state_shape]``. + If no valid windows exist, returns an empty tensor with shape + ``[0, unroll_length + 1, *state_shape]``. :rtype: torch.Tensor """ - - initial_data, unroll_data = self.create_unroll_windows( - data, unroll_instructions + starts = AutoregressiveSolver.decide_starting_indices( + data, unroll_length, num_unrolls, randomize ) - current_state = initial_data # [num_unrolls, features] + if starts.numel() == 0: + return torch.empty((0, unroll_length + 1, *data.shape[1:]), device=data.device) - losses = [] - for step in range(unroll_instructions.unroll_length): - - predicted_state = self.forward( - current_state - ) # [num_unrolls, features] - target_state = unroll_data[:, step, :] # [num_unrolls, features] - step_loss = self._loss_fn(predicted_state, target_state) - losses.append(step_loss) - current_state = predicted_state + windows = [data[int(start): int(start) + unroll_length + 1] for start in starts] - step_losses = torch.stack(losses) # [unroll_length] + return torch.stack(windows, dim=0) #[num_unrolls, unroll_length + 1, *data.shape[1:]] - with torch.no_grad(): - eps = unroll_instructions.eps - if eps is None: - weights = torch.ones_like(step_losses) - else: - weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) - weights = weights / weights.sum() + @staticmethod + def decide_starting_indices(data, unroll_length, num_unrolls=None, randomize=True): + """ + Determine starting indices for unroll windows. - return (step_losses * weights).sum() + Computes valid starting positions ensuring each window has enough + subsequent time steps for the specified unroll length. - def create_unroll_windows( - self, data, unroll_instructions: UnrollInstructions - ): + :param torch.Tensor data: Time series tensor with shape ``[T, *state_shape]``. + :param int unroll_length: Number of transitions in each window. + :param int num_unrolls: Maximum number of indices to return. If ``None``, + all valid indices are returned. Default is ``None``. + :param bool randomize: If ``True``, indices are randomly permuted before + applying ``num_unrolls``. Default is ``True``. + :return: 1D tensor of starting indices with dtype ``torch.long``. + :rtype: torch.Tensor """ - Create unroll windows from time series data. + n_step = int(data.shape[0]) + twin = int(unroll_length + 1) + last_start = n_step - twin + if last_start < 0: + return torch.empty(0, dtype=torch.long, device=data.device) - Slices the input time series into overlapping windows, each - consisting of an initial state and subsequent target states. + indices = torch.arange(last_start + 1, device=data.device) - :param torch.Tensor data: Time series with shape - ``[n_timesteps, n_features]``. - :param UnrollInstructions unroll_instructions: Configuration - specifying window length and count. - :return: Tuple of ``(initial_data, unroll_data)`` where: + if randomize: + indices = indices[torch.randperm(len(indices), device=data.device)] - - ``initial_data``: Shape ``[num_unrolls, n_features]`` - - ``unroll_data``: Shape ``[num_unrolls, unroll_length, n_features]`` + if num_unrolls is not None and num_unrolls < len(indices): + indices = indices[:num_unrolls] - :rtype: tuple[torch.Tensor, torch.Tensor] - """ + return indices - unroll_length = unroll_instructions.unroll_length - - start_list = [] - unroll_list = [] - for starting_index in self.decide_starting_indices( - data, unroll_instructions - ): - idx = starting_index.item() - start_list.append(data[idx]) - unroll_list.append(data[idx + 1 : idx + 1 + unroll_length, :]) - - initial_data = torch.stack(start_list) # [num_unrolls, features] - unroll_data = torch.stack( - unroll_list - ) # [num_unrolls, unroll_length, features] - return initial_data, unroll_data - - def decide_starting_indices( - self, data, unroll_instructions: UnrollInstructions - ): + def loss_data(self, unroll, eps=None, aggregation_strategy=None): """ - Determine starting indices for unroll windows. + Compute the autoregressive multi-step data loss. - Computes valid starting positions ensuring each window has - enough subsequent time steps for the specified unroll length. + The input ``unroll`` is expected to be a batch of precomputed unroll windows + with shape ``[B, Twin, *state_shape]``. The first element along the ``Twin`` + axis is used as current state, and the following elements are the targets. - :param torch.Tensor data: Time series with shape - ``[n_timesteps, n_features]``. - :param UnrollInstructions unroll_instructions: Configuration - with ``unroll_length``, ``num_unrolls``, and ``randomize``. - :return: 1D tensor of starting indices. + :param torch.Tensor unroll: Batch of unroll windows with shape + ``[B, Twin, *state_shape]`` where ``Twin = unroll_length + 1``. + :param float eps: If provided, applies step weighting through + :meth:`weighting_strategy`. If ``None``, uniform normalized weights are used. + Default is ``None``. + :param callable aggregation_strategy: Reduction applied to the weighted per-step + losses. If ``None``, :func:`torch.sum` is used. Default is ``None``. + :return: Scalar loss value for the given batch. :rtype: torch.Tensor """ - n_step, n_features = data.shape - num_unrolls = unroll_instructions.num_unrolls - if unroll_instructions.unroll_length >= n_step: - return [] # TODO: decide if it is better to raise an error here - - max_start = n_step - unroll_instructions.unroll_length - indices = torch.arange(max_start, device=data.device) + # batch dimensition is unroll.shape[0] -the number of unrolls- + Twin = unroll.shape[1] - if num_unrolls is not None and num_unrolls < len(indices): - indices = indices[:num_unrolls] + current_state = unroll[:, 0, ...] # first time step of each batch + losses = [] + for step in range(1, Twin): - if unroll_instructions.randomize: - indices = indices[torch.randperm(len(indices), device=data.device)] + predicted_state = self.forward( + current_state + ) # [num_unrolls, features] + target_state = unroll[:, step, ...] # [num_unrolls, features] + step_loss = self._loss_fn(predicted_state, target_state) + losses.append(step_loss) + current_state = predicted_state - return indices + step_losses = torch.stack(losses) # [unroll_length] + + with torch.no_grad(): + weights = AutoregressiveSolver.weighting_strategy(step_losses, eps) + + if aggregation_strategy is None: + aggregation_strategy = torch.sum + + return aggregation_strategy(step_losses * weights) + + @staticmethod + def weighting_strategy(step_losses, eps=None, clamp=50.0): + """ + Compute normalized weights for per-step losses. + + :param torch.Tensor step_losses: 1D tensor of per-step losses with shape + ``[Twin - 1]``. + :param float eps: Weighting strength. If ``None``, uniform normalized weights + are returned. Default is ``None``. + :param float clamp: Clamp applied to log-weights before normalization to + improve numerical stability. Default is ``50.0``. + :return: 1D tensor of weights with the same shape as ``step_losses`` that sums to 1. + :rtype: torch.Tensor + """ + if eps is None: + weight = torch.ones_like(step_losses) / step_losses.numel() + else: + log_w = -eps * torch.cumsum(step_losses, dim=0) + log_w = torch.clamp(log_w, min=-clamp, max=clamp) # prevent overflow + weight = torch.softmax(log_w, dim=0) + # weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) + return weight def predict(self, initial_state, num_steps): """ Generate predictions by recursively applying the model. - Starting from the initial state, applies the model repeatedly - to generate a trajectory of predicted states. - - :param torch.Tensor initial_state: Starting state with shape - ``[n_features]``. - :param int num_steps: Number of future time steps to predict. - :return: Predicted trajectory with shape - ``[num_steps + 1, n_features]``, where the first row is - the initial state. - :rtype: torch.Tensor + Starting from ``initial_state``, applies the model repeatedly to generate + a trajectory of length ``num_steps + 1`` (including the initial state). + + :param torch.Tensor initial_state: Starting state. Supported shapes: + - ``[n_features]`` (unbatched, 1D) + - ``[B, n_features]`` (batched) + More general tensors ``[*state_shape]`` / ``[B, *state_shape]`` are also + supported, provided the model can process them. + :param int num_steps: Number of future time steps to predict. + :return: Predicted trajectory including the initial state. Shape: + - ``[num_steps + 1, *state_shape]`` if unbatched input + - ``[num_steps + 1, B, *state_shape]`` if batched input + :rtype: torch.Tensor """ self.eval() # Set model to evaluation mode current_state = initial_state - predictions = [current_state] + added_batch = False + if current_state.dim() == 1: + current_state = current_state.unsqueeze(0) + added_batch = True + + predictions = [current_state] with torch.no_grad(): for step in range(num_steps): next_state = self.forward(current_state) predictions.append(next_state) current_state = next_state - return torch.stack(predictions) + out = torch.stack(predictions, dim=0) + if added_batch: + out = out[:, 0, ...] # remove batch dimension + + return out diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index 67f64b6cd..b76fd842e 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -11,69 +11,28 @@ from ...condition import DataCondition from typing import Optional - -@dataclass -class UnrollInstructions: - """ - Configuration for creating unroll windows from a time series condition. - - This dataclass specifies how to slice a time series into overlapping - windows for autoregressive training. Each window consists of an initial - state and a sequence of subsequent states used as targets. - - :param str condition_name: Name of the condition in the problem's - conditions dictionary. Must match a key in - ``problem.conditions``. - :param int unroll_length: The length of each unroll window. - :param Optional[int] num_unrolls: The number of unroll windows to create. - If ``None``, all possible windows are used. Default is None. - :param bool randomize: Whether to randomize the starting indices of the unroll windows. - Default is True. - :param Optional[float] eps: Epsilon parameter for exponential loss weighting. - If ``None``, uniform weighting is applied. Default is ``None``. - - :Example: - >>> instructions = UnrollInstructions( - ... condition_name="trajectory", - ... unroll_length=10, - ... num_unrolls=100, - ... randomize=True, - ... eps=0.1 - ... ) - """ - - condition_name: str - unroll_length: int - num_unrolls: Optional[int] = None - randomize: bool = True - eps: Optional[float] = None - - class AutoregressiveSolverInterface(SolverInterface): """ Base class for autoregressive solvers. - This interface defines solvers that learn to predict the next state - of a dynamical system given the current state. The solver uses an - unrolling strategy where predictions are made recursively over - multiple time steps during training. - The ``AutoregressiveSolverInterface`` is compatible with problems - containing :class:`~pina.condition.data_condition.DataCondition` - conditions, where the input represents a time series trajectory. + The training pipeline expects :class:`~pina.condition.data_condition.DataCondition` + conditions. In the recommended configuration, each DataCondition input is a + collection of unroll windows with shape ``[Nw, Twin, *state_shape]``, where + ``Twin = unroll_length + 1``. The Trainer batches along the first axis, producing + ``[B, Twin, *state_shape]`` tensors passed to :meth:`loss_data`. """ - def __init__(self, unroll_instructions_list, loss=None, **kwargs): + def __init__(self, conditions_settings={}, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - :param list[UnrollInstructions] unroll_instructions_list: List of - :class:`UnrollInstructions` objects, one for each condition - in the problem. Each instruction specifies how to create - unroll windows for training. - ::param torch.nn.Module loss: Loss function to minimize. - If ``None``, :class:`torch.nn.MSELoss` is used. - Default is ``None``. - :param kwargs: Additional keyword arguments passed to + :param dict conditions_settings: Dictionary mapping condition names to a + dictionary of keyword arguments forwarded to :meth:`loss_data`. + Example keys: ``eps``, ``aggregation_strategy``. + If ``None``, an empty dict is used. Default is ``None``. + :param torch.nn.Module loss: Loss function to minimize. If ``None``, + :class:`torch.nn.MSELoss` is used. Default is ``None``. + :param kwargs: Additional keyword arguments forwarded to :class:`~pina.solver.solver.SolverInterface`. """ @@ -84,7 +43,8 @@ def __init__(self, unroll_instructions_list, loss=None, **kwargs): check_consistency(loss, (LossInterface, _Loss), subclass=False) self._loss_fn = loss - self._unroll_instructions_list = unroll_instructions_list + + self.conditions_settings = conditions_settings def optimization_cycle(self, batch): """ @@ -100,48 +60,35 @@ def optimization_cycle(self, batch): condition_loss = {} for condition_name, points in batch: - # find unroll instructions for this condition - unroll_instructions = next( - ui - for ui in self._unroll_instructions_list - if ui.condition_name == condition_name - ) + condition_settings = self.conditions_settings.get(condition_name, {}) loss = self.loss_data( points["input"], - unroll_instructions, + **condition_settings ) condition_loss[condition_name] = loss return condition_loss @abstractmethod - def loss_data(self, input, unroll_instructions: UnrollInstructions): + def loss_data(self, input, **settings): """ Compute the data loss for each condition. This method must be implemented by subclasses to define the specific loss computation strategy. - - :param torch.Tensor input: Time series data with shape - ``[n_timesteps, n_features]``. - :param UnrollInstructions unroll_instructions: Configuration - for creating unroll windows from the input data. - :return: Scalar loss value for this condition. - :rtype: torch.Tensor """ pass @abstractmethod def predict(self, initial_state, num_steps): """ - Generate recursive predictions from an initial state. - - Starting from the initial state, repeatedly applies the model - to predict subsequent states. - - :param torch.Tensor initial_state: Starting state with shape - ``[n_features]`` or ``[batch_size, n_features]``. - :param int num_steps: Number of future steps to predict. - :return: Tensor of predictions with shape - ``[num_steps + 1, n_features]``, including the initial state. + Generate predictions by recursively applying the model. + + :param torch.Tensor initial_state: Starting state. Supported shapes are: + - ``[*state_shape]`` (unbatched) + - ``[B, *state_shape]`` (batched) + :param int num_steps: Number of future time steps to predict. + :return: Predicted trajectory including the initial state. Shape: + - ``[num_steps + 1, *state_shape]`` if unbatched input + - ``[num_steps + 1, B, *state_shape]`` if batched input :rtype: torch.Tensor """ pass diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py index 467f90afe..d1497b24d 100644 --- a/tests/test_solver/test_autoregressive_solver.py +++ b/tests/test_solver/test_autoregressive_solver.py @@ -5,21 +5,23 @@ from pina.optim import TorchOptimizer from pina.problem import AbstractProblem from pina.condition.data_condition import DataCondition -from pina.solver import AutoregressiveSolver, UnrollInstructions +from pina.solver import AutoregressiveSolver NUM_TIMESTEPS = 10 NUM_FEATURES = 3 -@pytest.fixture -def y_data(): +def _make_series(T=NUM_TIMESTEPS, F=NUM_FEATURES): torch.manual_seed(42) - y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) - y[0] = torch.rand(NUM_FEATURES) - for t in range(NUM_TIMESTEPS - 1): + y = torch.zeros(T, F) + y[0] = torch.rand(F) + for t in range(T - 1): y[t + 1] = 0.95 * y[t] return y +@pytest.fixture +def y_data(): + return _make_series() # crate a test Model class ExactModel(torch.nn.Module): @@ -38,176 +40,95 @@ def forward(self, x): return next_state + 0.0 * self.dummy_param -@pytest.fixture -def solver(y_data): - """Create a minimal solver for testing internal methods.""" - - class Problem(AbstractProblem): - output_variables = None - input_variables = None - conditions = {"data": DataCondition(input=y_data)} +# Tests start here ============================================== - return AutoregressiveSolver( - unroll_instructions_list=[ - UnrollInstructions(condition_name="data", unroll_length=3) - ], - problem=Problem(), - model=ExactModel(), +def test_unroll_shape_and_content(y_data): + # unroll_length=4 -> Twin=5 + w = AutoregressiveSolver.unroll(y_data, unroll_length=4, num_unrolls=2, randomize=False) + assert w.shape == (2, 5, NUM_FEATURES) + # deterministic starts: 0 and 1 + assert torch.allclose(w[0], y_data[0:5]) + assert torch.allclose(w[1], y_data[1:6]) + +def test_decide_starting_indices_edge_cases(y_data): + idx = AutoregressiveSolver.decide_starting_indices(y_data, unroll_length=3, num_unrolls=None, randomize=False) + # T=10, Twin=4 => last_start=6 => 0..6 + assert torch.equal(idx, torch.arange(7)) + + idx_empty = AutoregressiveSolver.decide_starting_indices( + y_data, unroll_length=NUM_TIMESTEPS + 5, num_unrolls=None, randomize=False ) + assert idx_empty.numel() == 0 -# Tests start here ============================================== +def test_exact_model(y_data): + windows = AutoregressiveSolver.unroll( + y_data, unroll_length=5, num_unrolls=4, randomize=False) -def test_exact_model(y_data): class Problem(AbstractProblem): output_variables = None input_variables = None conditions = { - "data_condition": DataCondition(input=y_data), + "data_condition": DataCondition(input=windows), } - unroll_instruction = UnrollInstructions( - condition_name="data_condition", - unroll_length=5, - ) - + conditions_settings = { + "data_condition": {"eps": None, "aggregation_strategy": torch.sum}, + } solver = AutoregressiveSolver( - unroll_instructions_list=[unroll_instruction], problem=Problem(), + conditions_settings=conditions_settings, model=ExactModel(), optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01), ) - loss = solver.loss_data(y_data, unroll_instruction) + loss = solver.loss_data(windows, **conditions_settings["data_condition"]) assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) - trainer = Trainer( - solver=solver, - max_epochs=5, - accelerator="cpu", - shuffle=False, - enable_model_summary=False, - ) - trainer.train() - - loss_after_training = solver.loss_data(y_data, unroll_instruction) - assert torch.isclose(loss_after_training, torch.tensor(0.0), atol=1e-6) - - predictions = solver.predict( - initial_state=y_data[0], num_steps=NUM_TIMESTEPS - 1 - ) - expected_predictions = y_data - assert torch.allclose(predictions, expected_predictions, atol=1e-6) - - -def test_indices_sequential_when_no_randomize(solver, y_data): - """Indices should be [0, 1, 2, ...] when randomize=False.""" - instructions = UnrollInstructions( - condition_name="data", - unroll_length=3, - randomize=False, - ) - indices = solver.decide_starting_indices(y_data, instructions) - - # y_data has 10 timesteps, unroll_length=3 → max_start = 10 - 3 = 7 - expected = torch.arange(7) - assert torch.equal(indices, expected) - +def test_predict_matches_ground_truth(y_data): + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {"data": DataCondition(input=y_data)} -def test_indices_permuted_when_randomize(solver, y_data): - """Indices should contain same values but permuted when randomize=True.""" - instructions = UnrollInstructions( - condition_name="data", - unroll_length=3, - randomize=True, - ) - indices = solver.decide_starting_indices(y_data, instructions) + solver = AutoregressiveSolver(problem=Problem(), model=ExactModel()) - expected_values = set(range(7)) - actual_values = set(indices.tolist()) - assert actual_values == expected_values + pred = solver.predict(y_data[0], num_steps=NUM_TIMESTEPS - 1) + assert pred.shape == y_data.shape + assert torch.allclose(pred, y_data, atol=1e-6) +def test_weighting_strategy_is_finite_and_normalized(): + step_losses = torch.tensor([1.0, 2.0, 3.0]) + w = AutoregressiveSolver.weighting_strategy(step_losses, eps=1.0) + assert torch.isfinite(w).all() + assert torch.isclose(w.sum(), torch.tensor(1.0), atol=1e-6) -def test_num_unrolls_parameter(solver, y_data): - """num_unrolls should limit the number of indices returned.""" - instructions = UnrollInstructions( - condition_name="data", - unroll_length=3, - num_unrolls=3, - randomize=False, - ) - indices = solver.decide_starting_indices(y_data, instructions) - - assert len(indices) == 3 - assert torch.equal(indices, torch.arange(3)) - - -def test_num_unrolls_greater_than_max_possible(solver, y_data): - """num_unrolls > max_possible should return all possible indices.""" - unroll_length = 3 - maximum_number_of_unrolls = ( - NUM_TIMESTEPS - unroll_length - ) # 10 - unroll_length(3) = 7 - instructions = UnrollInstructions( - condition_name="data", - unroll_length=unroll_length, - num_unrolls=100, - randomize=False, - ) - indices = solver.decide_starting_indices(y_data, instructions) + w2 = AutoregressiveSolver.weighting_strategy(step_losses, eps=None) + assert torch.isclose(w2.sum(), torch.tensor(1.0), atol=1e-6) - assert len(indices) == maximum_number_of_unrolls +def test_trainer_integration_one_epoch(y_data): + windows = AutoregressiveSolver.unroll(y_data, unroll_length=5, num_unrolls=None, randomize=False) + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {"data": DataCondition(input=windows)} -def test_no_valid_indices_when_unroll_too_long(solver, y_data): - """When unroll_length >= n_timesteps, no valid indices exist.""" - instructions = UnrollInstructions( - condition_name="data", - unroll_length=NUM_TIMESTEPS + 1, - randomize=False, - ) - indices = solver.decide_starting_indices(y_data, instructions) - print(indices) - assert len(indices) == 0 - - -def test_unroll_window_shape(solver, y_data): - """Unroll windows should have correct shapes.""" - instructions = UnrollInstructions( - condition_name="data", - unroll_length=4, - num_unrolls=2, - randomize=False, - ) - initial_data, unroll_data = solver.create_unroll_windows( - y_data, instructions + solver = AutoregressiveSolver( + problem=Problem(), + model=ExactModel(), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-2), + conditions_settings={"data": {"eps": None, "aggregation_strategy": torch.sum}}, ) - assert initial_data.shape == (2, NUM_FEATURES) # [num_unrolls, features] - assert unroll_data.shape == ( - 2, - 4, - NUM_FEATURES, - ) # [num_unrolls, unroll_length, features] - - -def test_unroll_windows_content(solver, y_data): - """Verify actual content of unroll windows.""" - instructions = UnrollInstructions( - condition_name="data", - unroll_length=3, - num_unrolls=2, - randomize=False, - ) - initial_data, unroll_data = solver.create_unroll_windows( - y_data, instructions + trainer = Trainer( + solver=solver, + max_epochs=1, ) + trainer.train() - # initial_data[i] should be y_data[i] - assert torch.equal(initial_data[0], y_data[0]) - assert torch.equal(initial_data[1], y_data[1]) - - # unroll_data[i] should be y_data[i+1 : i+1+unroll_length] - assert torch.equal(unroll_data[0], y_data[1:4]) - assert torch.equal(unroll_data[1], y_data[2:5]) + # Just check we didn't produce NaNs somewhere + with torch.no_grad(): + loss = solver.loss_data(windows[:4], eps=None, aggregation_strategy=torch.sum) + assert torch.isfinite(loss) \ No newline at end of file From 6175854b0d6f131c72715177905bd73d9dc2f301 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Thu, 15 Jan 2026 17:45:31 +0100 Subject: [PATCH 07/10] format code --- .../autoregressive_solver.py | 32 +++++++++++----- .../autoregressive_solver_interface.py | 8 ++-- .../test_solver/test_autoregressive_solver.py | 37 +++++++++++++++---- 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index eb38b1b0c..349300c7e 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -5,6 +5,7 @@ from .autoregressive_solver_interface import AutoregressiveSolverInterface from typing import List + class AutoregressiveSolver( AutoregressiveSolverInterface, SingleSolverInterface ): @@ -77,7 +78,7 @@ def __init__( weighting=weighting, use_lt=use_lt, ) - + @staticmethod def unroll( data, unroll_length: int, num_unrolls=None, randomize: bool = True @@ -107,14 +108,23 @@ def unroll( data, unroll_length, num_unrolls, randomize ) if starts.numel() == 0: - return torch.empty((0, unroll_length + 1, *data.shape[1:]), device=data.device) + return torch.empty( + (0, unroll_length + 1, *data.shape[1:]), device=data.device + ) - windows = [data[int(start): int(start) + unroll_length + 1] for start in starts] + windows = [ + data[int(start) : int(start) + unroll_length + 1] + for start in starts + ] - return torch.stack(windows, dim=0) #[num_unrolls, unroll_length + 1, *data.shape[1:]] + return torch.stack( + windows, dim=0 + ) # [num_unrolls, unroll_length + 1, *data.shape[1:]] @staticmethod - def decide_starting_indices(data, unroll_length, num_unrolls=None, randomize=True): + def decide_starting_indices( + data, unroll_length, num_unrolls=None, randomize=True + ): """ Determine starting indices for unroll windows. @@ -167,7 +177,7 @@ def loss_data(self, unroll, eps=None, aggregation_strategy=None): # batch dimensition is unroll.shape[0] -the number of unrolls- Twin = unroll.shape[1] - current_state = unroll[:, 0, ...] # first time step of each batch + current_state = unroll[:, 0, ...] # first time step of each batch losses = [] for step in range(1, Twin): @@ -180,15 +190,15 @@ def loss_data(self, unroll, eps=None, aggregation_strategy=None): current_state = predicted_state step_losses = torch.stack(losses) # [unroll_length] - + with torch.no_grad(): weights = AutoregressiveSolver.weighting_strategy(step_losses, eps) if aggregation_strategy is None: aggregation_strategy = torch.sum - + return aggregation_strategy(step_losses * weights) - + @staticmethod def weighting_strategy(step_losses, eps=None, clamp=50.0): """ @@ -207,7 +217,9 @@ def weighting_strategy(step_losses, eps=None, clamp=50.0): weight = torch.ones_like(step_losses) / step_losses.numel() else: log_w = -eps * torch.cumsum(step_losses, dim=0) - log_w = torch.clamp(log_w, min=-clamp, max=clamp) # prevent overflow + log_w = torch.clamp( + log_w, min=-clamp, max=clamp + ) # prevent overflow weight = torch.softmax(log_w, dim=0) # weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) return weight diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index b76fd842e..b526dd1e4 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -11,6 +11,7 @@ from ...condition import DataCondition from typing import Optional + class AutoregressiveSolverInterface(SolverInterface): """ Base class for autoregressive solvers. @@ -60,11 +61,10 @@ def optimization_cycle(self, batch): condition_loss = {} for condition_name, points in batch: - condition_settings = self.conditions_settings.get(condition_name, {}) - loss = self.loss_data( - points["input"], - **condition_settings + condition_settings = self.conditions_settings.get( + condition_name, {} ) + loss = self.loss_data(points["input"], **condition_settings) condition_loss[condition_name] = loss return condition_loss diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py index d1497b24d..0a6e77767 100644 --- a/tests/test_solver/test_autoregressive_solver.py +++ b/tests/test_solver/test_autoregressive_solver.py @@ -19,10 +19,12 @@ def _make_series(T=NUM_TIMESTEPS, F=NUM_FEATURES): y[t + 1] = 0.95 * y[t] return y + @pytest.fixture def y_data(): return _make_series() + # crate a test Model class ExactModel(torch.nn.Module): """ @@ -42,21 +44,30 @@ def forward(self, x): # Tests start here ============================================== + def test_unroll_shape_and_content(y_data): # unroll_length=4 -> Twin=5 - w = AutoregressiveSolver.unroll(y_data, unroll_length=4, num_unrolls=2, randomize=False) + w = AutoregressiveSolver.unroll( + y_data, unroll_length=4, num_unrolls=2, randomize=False + ) assert w.shape == (2, 5, NUM_FEATURES) # deterministic starts: 0 and 1 assert torch.allclose(w[0], y_data[0:5]) assert torch.allclose(w[1], y_data[1:6]) + def test_decide_starting_indices_edge_cases(y_data): - idx = AutoregressiveSolver.decide_starting_indices(y_data, unroll_length=3, num_unrolls=None, randomize=False) + idx = AutoregressiveSolver.decide_starting_indices( + y_data, unroll_length=3, num_unrolls=None, randomize=False + ) # T=10, Twin=4 => last_start=6 => 0..6 assert torch.equal(idx, torch.arange(7)) idx_empty = AutoregressiveSolver.decide_starting_indices( - y_data, unroll_length=NUM_TIMESTEPS + 5, num_unrolls=None, randomize=False + y_data, + unroll_length=NUM_TIMESTEPS + 5, + num_unrolls=None, + randomize=False, ) assert idx_empty.numel() == 0 @@ -64,7 +75,8 @@ def test_decide_starting_indices_edge_cases(y_data): def test_exact_model(y_data): windows = AutoregressiveSolver.unroll( - y_data, unroll_length=5, num_unrolls=4, randomize=False) + y_data, unroll_length=5, num_unrolls=4, randomize=False + ) class Problem(AbstractProblem): output_variables = None @@ -86,6 +98,7 @@ class Problem(AbstractProblem): loss = solver.loss_data(windows, **conditions_settings["data_condition"]) assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) + def test_predict_matches_ground_truth(y_data): class Problem(AbstractProblem): output_variables = None @@ -98,6 +111,7 @@ class Problem(AbstractProblem): assert pred.shape == y_data.shape assert torch.allclose(pred, y_data, atol=1e-6) + def test_weighting_strategy_is_finite_and_normalized(): step_losses = torch.tensor([1.0, 2.0, 3.0]) w = AutoregressiveSolver.weighting_strategy(step_losses, eps=1.0) @@ -107,8 +121,11 @@ def test_weighting_strategy_is_finite_and_normalized(): w2 = AutoregressiveSolver.weighting_strategy(step_losses, eps=None) assert torch.isclose(w2.sum(), torch.tensor(1.0), atol=1e-6) + def test_trainer_integration_one_epoch(y_data): - windows = AutoregressiveSolver.unroll(y_data, unroll_length=5, num_unrolls=None, randomize=False) + windows = AutoregressiveSolver.unroll( + y_data, unroll_length=5, num_unrolls=None, randomize=False + ) class Problem(AbstractProblem): output_variables = None @@ -119,7 +136,9 @@ class Problem(AbstractProblem): problem=Problem(), model=ExactModel(), optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-2), - conditions_settings={"data": {"eps": None, "aggregation_strategy": torch.sum}}, + conditions_settings={ + "data": {"eps": None, "aggregation_strategy": torch.sum} + }, ) trainer = Trainer( @@ -130,5 +149,7 @@ class Problem(AbstractProblem): # Just check we didn't produce NaNs somewhere with torch.no_grad(): - loss = solver.loss_data(windows[:4], eps=None, aggregation_strategy=torch.sum) - assert torch.isfinite(loss) \ No newline at end of file + loss = solver.loss_data( + windows[:4], eps=None, aggregation_strategy=torch.sum + ) + assert torch.isfinite(loss) From 2eadb5a6fc47343b0af8acef91799f9981ca7c70 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Fri, 16 Jan 2026 18:53:00 +0100 Subject: [PATCH 08/10] improve stability of adaptive weights --- autoregressive_codice_prova.py | 185 ++++++++++++++++++ .../autoregressive_solver.py | 88 ++++++--- .../autoregressive_solver_interface.py | 21 +- 3 files changed, 260 insertions(+), 34 deletions(-) create mode 100644 autoregressive_codice_prova.py diff --git a/autoregressive_codice_prova.py b/autoregressive_codice_prova.py new file mode 100644 index 000000000..69b194d35 --- /dev/null +++ b/autoregressive_codice_prova.py @@ -0,0 +1,185 @@ +import torch +import matplotlib.pyplot as plt + +from pina import Trainer +from pina.optim import TorchOptimizer +from pina.problem import AbstractProblem +from pina.condition.data_condition import DataCondition +from pina.solver import AutoregressiveSolver + +NUM_TIMESTEPS = 100 +NUM_FEATURES = 15 +USE_TEST_MODEL = False + +# ============================================================================ +# DATA +# ============================================================================ + +torch.manual_seed(42) + +y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) +y[0] = torch.rand(NUM_FEATURES) # Random initial state + +for t in range(NUM_TIMESTEPS - 1): + y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum()) + +# ============================================================================ +# TRAINING +# ============================================================================ + +class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.Sequential( + torch.nn.Linear(y.shape[1], 15), + torch.nn.Tanh(), + # torch.nn.Dropout(0.1), + torch.nn.Linear(15, y.shape[1]), + ) + + def forward(self, x): + return x + self.layers(x) + + +class TestModel(torch.nn.Module): + """ + Debug model that implements the EXACT transformation rule. + y[t+1] = 0.95 * y[t] + Expected loss is zero + """ + + def __init__(self, data_series=None): + super().__init__() + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True)) + return next_state + 0.0 * self.dummy_param + +# create a problem with duplicated data conditions +class Problem(AbstractProblem): + output_variables = None + input_variables = None + + # create two different unroll datasets: short and medium + y_short = AutoregressiveSolver.unroll( + y, unroll_length=4, num_unrolls=20, randomize=False + ) + y_medium = AutoregressiveSolver.unroll( + y, unroll_length=10, num_unrolls=15, randomize=False + ) + y_long = AutoregressiveSolver.unroll( + y, unroll_length=20, num_unrolls=10, randomize=False + ) + + conditions = {} + + inactive_conditions = { + "short": DataCondition(input=y_short), + "medium": DataCondition(input=y_medium), + "long": DataCondition(input=y_long), + } + + # Settings kept separate from the DataCondition objects + conditions_settings = { + "short": {"eps": 0.1}, + "medium": {"eps": 1.0}, + "long": {"eps": 2.0}, + } + + +problem = Problem() + +# helper that allows to activate or replace a condition at runtime +def activate_condition(problem, name, data=None, settings=None): + """ + Activate a single condition by name. + + `conditions_settings` is left untouched unless `settings` is explicitly + provided and no entry exists yet for `name`. + """ + # if data is provided, (re)register condition in inactive store + if data is not None: + problem.inactive_conditions[name] = DataCondition(input=data) + + problem.conditions = {} + problem.conditions[name] = problem.inactive_conditions[name] + + if settings is not None: + problem.conditions_settings[name] = settings + +# configure solver and trainer +solver = AutoregressiveSolver( + problem=problem, + model=TestModel() if USE_TEST_MODEL else SimpleModel(), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.011), +) + + +print("Beginning phase 1: training with 'short' condition only") +activate_condition(problem, "short") +trainer1 = Trainer(solver, max_epochs=300, accelerator="cpu", enable_model_summary=False) +trainer1.train() + +print("Beginning phase 2: training with 'medium' condition added") +activate_condition(problem, "medium") +trainer2 = Trainer(solver, max_epochs=500, accelerator="cpu", enable_model_summary=False) +trainer2.train() + +print("Beginning phase 3: training with 'long' condition added") +activate_condition(problem, "long") +trainer3 = Trainer(solver, max_epochs=900, accelerator="cpu", enable_model_summary=False) +trainer3.train() + + +# ============================================================================ +test_start_idx = 50 +num_prediction_steps = 49 +initial_state = y[test_start_idx] # Shape: [features] +predictions = solver.predict(initial_state, num_prediction_steps) +actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1] + +print("\n=== PREDICTION DEBUG ===") +for i in range(min(10, num_prediction_steps)): + pred_val = predictions[i].mean().item() + actual_val = actual[i].mean().item() + error = (predictions[i] - actual[i]).abs().mean().item() + print(f"Step {i}: pred={pred_val:.4f}, actual={actual_val:.4f}, error={error:.4f}") + +total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:]) +print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}") + +# visualize single dof +dof_to_plot = [0, 3, 6, 9, 12] +colors = [ + "r", + "g", + "b", + "c", + "m", + "y", + "k", +] +plt.figure(figsize=(10, 6)) +for dof, color in zip(dof_to_plot, colors): + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + actual[:, dof].numpy(), + label="Actual", + marker="o", + color=color, + markerfacecolor="none", + ) + plt.plot( + range(test_start_idx, test_start_idx + num_prediction_steps + 1), + predictions[:, dof].numpy(), + label="Predicted", + marker="x", + color=color, + ) + +plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}") +plt.legend() +plt.xlabel("Timestep") +plt.savefig(f"autoregressive_predictions.png") +plt.close() diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index 349300c7e..e377804c7 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -4,6 +4,9 @@ from pina.condition import DataCondition from .autoregressive_solver_interface import AutoregressiveSolverInterface from typing import List +import logging + +logger = logging.getLogger(__name__) class AutoregressiveSolver( @@ -39,12 +42,12 @@ def __init__( self, problem, model, - conditions_settings={}, loss=None, optimizer=None, scheduler=None, weighting=None, use_lt=False, + N_epochs_with_same_weights=10, ): """ Initialization of the :class:`AutoregressiveSolver` class. @@ -66,18 +69,24 @@ def __init__( If ``None``, uniform weighting is used. Default is ``None``. :param bool use_lt: Whether to use LabelTensors. Default is ``False``. + :param int N_epochs_with_same_weights: Number of epochs to keep the same adaptive weights + before recomputing them. Default is ``10``. """ super().__init__( problem=problem, model=model, - conditions_settings=conditions_settings, loss=loss, optimizer=optimizer, scheduler=scheduler, weighting=weighting, use_lt=use_lt, ) + # cache for per-condition adaptive weights and epoch-based update control + # this is the most generic way to implement periodic weight updates I found + self._cached_weights = {} + self._epochs_since_update = 0 + self.N_epochs_with_same_weights = N_epochs_with_same_weights @staticmethod def unroll( @@ -156,7 +165,7 @@ def decide_starting_indices( return indices - def loss_data(self, unroll, eps=None, aggregation_strategy=None): + def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name=None): """ Compute the autoregressive multi-step data loss. @@ -187,42 +196,73 @@ def loss_data(self, unroll, eps=None, aggregation_strategy=None): target_state = unroll[:, step, ...] # [num_unrolls, features] step_loss = self._loss_fn(predicted_state, target_state) losses.append(step_loss) + + if logger.isEnabledFor(logging.DEBUG) and (step <= 3 or torch.isnan(step_loss)): + logger.debug( + " Step %d: loss=%.4e, pred=[%.3f, %.3f]", + step, + float(step_loss.item()), + float(predicted_state.min()), + float(predicted_state.max()), + ) + current_state = predicted_state step_losses = torch.stack(losses) # [unroll_length] with torch.no_grad(): - weights = AutoregressiveSolver.weighting_strategy(step_losses, eps) + condition_name = condition_name or "default" + weights = self.get_weights(condition_name, step_losses, eps) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(" Losses: %s", step_losses.detach().cpu().numpy().round(4)) + logger.debug(" Weights: %s", weights.cpu().numpy().round(4)) + logger.debug(" Weight ratio: %.1f", float(weights.max() / weights.min())) if aggregation_strategy is None: aggregation_strategy = torch.sum return aggregation_strategy(step_losses * weights) - @staticmethod - def weighting_strategy(step_losses, eps=None, clamp=50.0): + def _compute_adaptive_weights(self, step_losses, eps): """ - Compute normalized weights for per-step losses. - - :param torch.Tensor step_losses: 1D tensor of per-step losses with shape - ``[Twin - 1]``. - :param float eps: Weighting strength. If ``None``, uniform normalized weights - are returned. Default is ``None``. - :param float clamp: Clamp applied to log-weights before normalization to - improve numerical stability. Default is ``50.0``. - :return: 1D tensor of weights with the same shape as ``step_losses`` that sums to 1. + Actual computation of adaptive weights. + :param torch.Tensor step_losses: 1D tensor of per-step losses. + :param float eps: Weighting parameter. + :return: Computed weights tensor. :rtype: torch.Tensor """ + print(f"updating weights, eps={eps}") + if eps is None: - weight = torch.ones_like(step_losses) / step_losses.numel() - else: - log_w = -eps * torch.cumsum(step_losses, dim=0) - log_w = torch.clamp( - log_w, min=-clamp, max=clamp - ) # prevent overflow - weight = torch.softmax(log_w, dim=0) - # weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0)) - return weight + return torch.ones_like(step_losses) / step_losses.numel() + + log_w = torch.clamp(-eps * torch.cumsum(step_losses, dim=0), -20, 20) + return torch.softmax(log_w, dim=0) + + def get_weights(self, condition_name, step_losses, eps): + """ + Return cached weights or compute new ones. + :param str condition_name: Name of the condition. + :param torch.Tensor step_losses: 1D tensor of per-step losses. + :param float eps: Weighting parameter. + :return: Weights tensor. + :rtype: torch.Tensor + """ + cached = self._cached_weights.get(condition_name, None) + if cached is None: + cached = self._compute_adaptive_weights(step_losses, eps).cpu() + self._cached_weights[condition_name] = cached + return cached.to(step_losses.device) + + def on_train_epoch_end(self): + """ + Hook called by Lightning at the end of each epoch. + Forces periodic recalculation of weights by clearing the cache. + """ + self._epochs_since_update += 1 + if self._epochs_since_update >= self.N_epochs_with_same_weights: + self._cached_weights.clear() + self._epochs_since_update = 0 def predict(self, initial_state, num_steps): """ diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index b526dd1e4..788f6c081 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -23,14 +23,10 @@ class AutoregressiveSolverInterface(SolverInterface): ``[B, Twin, *state_shape]`` tensors passed to :meth:`loss_data`. """ - def __init__(self, conditions_settings={}, loss=None, **kwargs): + def __init__(self, loss=None, **kwargs): """ Initialization of the :class:`AutoregressiveSolverInterface` class. - :param dict conditions_settings: Dictionary mapping condition names to a - dictionary of keyword arguments forwarded to :meth:`loss_data`. - Example keys: ``eps``, ``aggregation_strategy``. - If ``None``, an empty dict is used. Default is ``None``. :param torch.nn.Module loss: Loss function to minimize. If ``None``, :class:`torch.nn.MSELoss` is used. Default is ``None``. :param kwargs: Additional keyword arguments forwarded to @@ -45,8 +41,6 @@ def __init__(self, conditions_settings={}, loss=None, **kwargs): check_consistency(loss, (LossInterface, _Loss), subclass=False) self._loss_fn = loss - self.conditions_settings = conditions_settings - def optimization_cycle(self, batch): """ Optimization cycle for this family of solvers. @@ -61,10 +55,17 @@ def optimization_cycle(self, batch): condition_loss = {} for condition_name, points in batch: - condition_settings = self.conditions_settings.get( - condition_name, {} + settings = {} + if hasattr(self.problem, "conditions_settings"): + settings = self.problem.conditions_settings.get( + condition_name, {} + ) + + loss = self.loss_data( + points["input"], + eps=settings.get("eps"), + condition_name=condition_name, ) - loss = self.loss_data(points["input"], **condition_settings) condition_loss[condition_name] = loss return condition_loss From 370c217490b601c2bdd2600ecee6e5fe6fda186f Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Mon, 19 Jan 2026 17:35:29 +0100 Subject: [PATCH 09/10] add end-to-end-test and improve mean --- .../autoregressive_solver.py | 106 +++++++---- .../test_solver/test_autoregressive_solver.py | 174 +++++++++++++++--- 2 files changed, 218 insertions(+), 62 deletions(-) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index e377804c7..756c7e944 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -47,7 +47,7 @@ def __init__( scheduler=None, weighting=None, use_lt=False, - N_epochs_with_same_weights=10, + reset_weighting_at_epoch_start=True, ): """ Initialization of the :class:`AutoregressiveSolver` class. @@ -69,8 +69,11 @@ def __init__( If ``None``, uniform weighting is used. Default is ``None``. :param bool use_lt: Whether to use LabelTensors. Default is ``False``. - :param int N_epochs_with_same_weights: Number of epochs to keep the same adaptive weights - before recomputing them. Default is ``10``. + :param bool reset_weighting_at_epoch_start: If ``True``, resets + the running averages used for adaptive weighting at the start + of each epoch. Default is ``True``. This parameter is for an advanced + use case, setting it to False can improve stability, especially + when data per epoch are very scarse. """ super().__init__( @@ -82,11 +85,9 @@ def __init__( weighting=weighting, use_lt=use_lt, ) - # cache for per-condition adaptive weights and epoch-based update control - # this is the most generic way to implement periodic weight updates I found - self._cached_weights = {} - self._epochs_since_update = 0 - self.N_epochs_with_same_weights = N_epochs_with_same_weights + self._running_avg_step_losses = {} + self._running_step_counts = {} + self.reset_weighting_at_epoch_start = reset_weighting_at_epoch_start @staticmethod def unroll( @@ -165,7 +166,9 @@ def decide_starting_indices( return indices - def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name=None): + def loss_data( + self, unroll, eps=None, aggregation_strategy=None, condition_name=None + ): """ Compute the autoregressive multi-step data loss. @@ -197,15 +200,15 @@ def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name= step_loss = self._loss_fn(predicted_state, target_state) losses.append(step_loss) - if logger.isEnabledFor(logging.DEBUG) and (step <= 3 or torch.isnan(step_loss)): + if step <= 3 or torch.isnan(step_loss): logger.debug( " Step %d: loss=%.4e, pred=[%.3f, %.3f]", step, float(step_loss.item()), - float(predicted_state.min()), - float(predicted_state.max()), + float(predicted_state.detach().min()), + float(predicted_state.detach().max()), ) - + current_state = predicted_state step_losses = torch.stack(losses) # [unroll_length] @@ -213,16 +216,52 @@ def loss_data(self, unroll, eps=None, aggregation_strategy=None, condition_name= with torch.no_grad(): condition_name = condition_name or "default" weights = self.get_weights(condition_name, step_losses, eps) - if logger.isEnabledFor(logging.DEBUG): - logger.debug(" Losses: %s", step_losses.detach().cpu().numpy().round(4)) - logger.debug(" Weights: %s", weights.cpu().numpy().round(4)) - logger.debug(" Weight ratio: %.1f", float(weights.max() / weights.min())) + + logger.debug( + " Losses: %s", step_losses.detach().cpu().numpy().round(4) + ) + logger.debug(" Weights: %s", weights.cpu().numpy().round(4)) + logger.debug( + " Weight ratio: %.1f", float(weights.max() / weights.min()) + ) if aggregation_strategy is None: aggregation_strategy = torch.sum return aggregation_strategy(step_losses * weights) + def get_weights(self, condition_name, step_losses, eps): + """ + Return cached weights or compute new ones. + :param str condition_name: Name of the condition. + :param torch.Tensor step_losses: 1D tensor of per-step losses. + :param float eps: Weighting parameter. + :return: Weights tensor. + :rtype: torch.Tensor + """ + key = condition_name or "default" + x = step_losses.detach() + + if x.dim() != 1: + raise ValueError( + f"step_losses must be a 1D tensor, got shape {x.shape}" + ) + + if key not in self._running_avg_step_losses: + self._running_avg_step_losses[key] = x.clone() + self._running_step_counts[key] = 1 + else: + self._running_step_counts[key] += 1 + k = self._running_step_counts[key] + # update running average + self._running_avg_step_losses[key] += ( + x - self._running_avg_step_losses[key] + ) / k + + return self._compute_adaptive_weights( + self._running_avg_step_losses[key], eps + ) + def _compute_adaptive_weights(self, step_losses, eps): """ Actual computation of adaptive weights. @@ -231,38 +270,25 @@ def _compute_adaptive_weights(self, step_losses, eps): :return: Computed weights tensor. :rtype: torch.Tensor """ - print(f"updating weights, eps={eps}") + logger.debug(f"updating weights, eps={eps}") if eps is None: return torch.ones_like(step_losses) / step_losses.numel() + # normalize to mean 1 (avoid too large exponents) + step_losses = step_losses / (step_losses.mean() + 1e-12) + log_w = torch.clamp(-eps * torch.cumsum(step_losses, dim=0), -20, 20) return torch.softmax(log_w, dim=0) - def get_weights(self, condition_name, step_losses, eps): - """ - Return cached weights or compute new ones. - :param str condition_name: Name of the condition. - :param torch.Tensor step_losses: 1D tensor of per-step losses. - :param float eps: Weighting parameter. - :return: Weights tensor. - :rtype: torch.Tensor - """ - cached = self._cached_weights.get(condition_name, None) - if cached is None: - cached = self._compute_adaptive_weights(step_losses, eps).cpu() - self._cached_weights[condition_name] = cached - return cached.to(step_losses.device) - - def on_train_epoch_end(self): + def on_train_epoch_start(self): """ - Hook called by Lightning at the end of each epoch. - Forces periodic recalculation of weights by clearing the cache. + Hook called by Lightning at the beginning of each epoch. + Forces periodic cleaning of he dictionaries used for weighting estimate. """ - self._epochs_since_update += 1 - if self._epochs_since_update >= self.N_epochs_with_same_weights: - self._cached_weights.clear() - self._epochs_since_update = 0 + if self.reset_weighting_at_epoch_start: + self._running_avg_step_losses.clear() + self._running_step_counts.clear() def predict(self, initial_state, num_steps): """ diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py index 0a6e77767..7126dec0e 100644 --- a/tests/test_solver/test_autoregressive_solver.py +++ b/tests/test_solver/test_autoregressive_solver.py @@ -7,11 +7,11 @@ from pina.condition.data_condition import DataCondition from pina.solver import AutoregressiveSolver -NUM_TIMESTEPS = 10 -NUM_FEATURES = 3 +# Set random seed for reproducibility +torch.manual_seed(42) -def _make_series(T=NUM_TIMESTEPS, F=NUM_FEATURES): +def _make_series(T, F): torch.manual_seed(42) y = torch.zeros(T, F) y[0] = torch.rand(F) @@ -20,12 +20,109 @@ def _make_series(T=NUM_TIMESTEPS, F=NUM_FEATURES): return y +### END-TO-END ############################################################################# + + +@pytest.fixture +def y_data_large(): + return _make_series(T=100, F=15) + + +class MinimalModel(torch.nn.Module): + """ + Minimal model that applies a linear transformation. + Used for end-to-end testing. Since the problem dynamic is linear, this model + should in principle learn the correct transformation. + """ + + def __init__(self): + super().__init__() + self.layers = torch.nn.Linear(15, 15, bias=False) + + def forward(self, x): + return x + self.layers(x) + + +def test_end_to_end(y_data_large): + """ + End-to-end test with MinimalModel. + This test performs a 3-phase training with increasing unroll lengths, shows how to use + the AutoregressiveSolver with curriculum learning + """ + + # AbstratProblem with empty conditions and conditions_settings to be filled later + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {} + conditions_settings = {} + + problem = Problem() + + solver = AutoregressiveSolver( + problem=problem, + model=MinimalModel(), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.018), + ) + # PHASE1: train with 'short' condition only + y_short = AutoregressiveSolver.unroll( + y_data_large, unroll_length=4, num_unrolls=20, randomize=False + ) + problem.conditions["short"] = DataCondition(input=y_short) + problem.conditions_settings["short"] = {"eps": 0.1} + trainer1 = Trainer( + solver, max_epochs=300, accelerator="cpu", enable_model_summary=False + ) + trainer1.train() + + # PHASE2: train with 'medium' condition only + y_medium = AutoregressiveSolver.unroll( + y_data_large, unroll_length=10, num_unrolls=15, randomize=False + ) + problem.conditions.clear() + problem.conditions["medium"] = DataCondition(input=y_medium) + problem.conditions_settings.clear() + problem.conditions_settings["medium"] = {"eps": 0.15} + trainer2 = Trainer( + solver, max_epochs=500, accelerator="cpu", enable_model_summary=False + ) + trainer2.train() + + # PHASE3: train with 'long' condition only + y_long = AutoregressiveSolver.unroll( + y_data_large, unroll_length=20, num_unrolls=10, randomize=False + ) + problem.conditions.clear() + problem.conditions["long"] = DataCondition(input=y_long) + problem.conditions_settings.clear() + problem.conditions_settings["long"] = {"eps": 0.3} + trainer3 = Trainer( + solver, max_epochs=400, accelerator="cpu", enable_model_summary=False + ) + trainer3.train() + + test_start_idx = 50 + num_predictions = 49 + start_state = y_data_large[test_start_idx] + ground_truth = y_data_large[ + test_start_idx : test_start_idx + num_predictions + 1 + ] + prediction = solver.predict(start_state, num_steps=num_predictions) + total_mse = torch.nn.functional.mse_loss(prediction[1:], ground_truth[1:]) + assert total_mse < 1e-6 + + +### UNIT TESTS ############################################################################# + +NUM_TIMESTEPS = 10 +NUM_FEATURES = 3 + + @pytest.fixture def y_data(): - return _make_series() + return _make_series(T=10, F=3) -# crate a test Model class ExactModel(torch.nn.Module): """ This model implements the EXACT transformation rule. @@ -42,9 +139,6 @@ def forward(self, x): return next_state + 0.0 * self.dummy_param -# Tests start here ============================================== - - def test_unroll_shape_and_content(y_data): # unroll_length=4 -> Twin=5 w = AutoregressiveSolver.unroll( @@ -73,7 +167,6 @@ def test_decide_starting_indices_edge_cases(y_data): def test_exact_model(y_data): - windows = AutoregressiveSolver.unroll( y_data, unroll_length=5, num_unrolls=4, randomize=False ) @@ -85,17 +178,18 @@ class Problem(AbstractProblem): "data_condition": DataCondition(input=windows), } - conditions_settings = { - "data_condition": {"eps": None, "aggregation_strategy": torch.sum}, - } solver = AutoregressiveSolver( problem=Problem(), - conditions_settings=conditions_settings, model=ExactModel(), optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01), ) - loss = solver.loss_data(windows, **conditions_settings["data_condition"]) + loss = solver.loss_data( + windows, + eps=None, + aggregation_strategy=torch.sum, + condition_name="data_condition", + ) assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6) @@ -112,15 +206,27 @@ class Problem(AbstractProblem): assert torch.allclose(pred, y_data, atol=1e-6) -def test_weighting_strategy_is_finite_and_normalized(): +def test_adaptive_weights_are_finite_and_normalized(y_data): + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {"data": DataCondition(input=y_data)} + + solver = AutoregressiveSolver(problem=Problem(), model=ExactModel()) + step_losses = torch.tensor([1.0, 2.0, 3.0]) - w = AutoregressiveSolver.weighting_strategy(step_losses, eps=1.0) + + w = solver._compute_adaptive_weights(step_losses, eps=1.0) assert torch.isfinite(w).all() assert torch.isclose(w.sum(), torch.tensor(1.0), atol=1e-6) - w2 = AutoregressiveSolver.weighting_strategy(step_losses, eps=None) + w2 = solver._compute_adaptive_weights(step_losses, eps=None) assert torch.isclose(w2.sum(), torch.tensor(1.0), atol=1e-6) + w3 = solver.get_weights("data", step_losses, eps=1.0) + assert torch.isfinite(w3).all() + assert torch.isclose(w3.sum(), torch.tensor(1.0), atol=1e-6) + def test_trainer_integration_one_epoch(y_data): windows = AutoregressiveSolver.unroll( @@ -136,20 +242,44 @@ class Problem(AbstractProblem): problem=Problem(), model=ExactModel(), optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-2), - conditions_settings={ - "data": {"eps": None, "aggregation_strategy": torch.sum} - }, ) trainer = Trainer( solver=solver, max_epochs=1, + accelerator="cpu", ) trainer.train() - # Just check we didn't produce NaNs somewhere with torch.no_grad(): loss = solver.loss_data( - windows[:4], eps=None, aggregation_strategy=torch.sum + windows[:4], + eps=None, + aggregation_strategy=torch.sum, + condition_name="data", ) assert torch.isfinite(loss) + + +def test_weight_cache_resets_on_epoch_start(y_data): + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {"data": DataCondition(input=y_data)} + + solver = AutoregressiveSolver( + problem=Problem(), + model=ExactModel(), + reset_weighting_at_epoch_start=True, + ) + + step_losses = torch.tensor([1.0, 2.0, 3.0]) + + _ = solver.get_weights("data", step_losses, eps=1.0) + assert "data" in solver._running_avg_step_losses + assert "data" in solver._running_step_counts + + solver.on_train_epoch_start() + + assert solver._running_avg_step_losses == {} + assert solver._running_step_counts == {} From 70cf989b098fb4215224822a88016be9483751d1 Mon Sep 17 00:00:00 2001 From: Davide Miotti Date: Wed, 21 Jan 2026 16:31:09 +0100 Subject: [PATCH 10/10] add support for multiple time series --- .../autoregressive_solver.py | 82 ++++++++++--------- .../autoregressive_solver_interface.py | 10 +-- .../test_solver/test_autoregressive_solver.py | 57 ++++++++----- 3 files changed, 85 insertions(+), 64 deletions(-) diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py index 756c7e944..5ea4c9b48 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver.py @@ -96,44 +96,60 @@ def unroll( """ Create unroll windows from time series data. - This is a pre-processing step. It slices the input time series into + This is a pre-processing step. It slices data into overlapping windows of length ``Twin = unroll_length + 1`` along the - time axis (axis 0). Each window contains the initial state and the + time axis which MUST BE data.shape[1]. + Each window contains the initial state and the subsequent target states used to compute a multi-step loss. - :param torch.Tensor data: Time series tensor with shape ``[T, *state_shape]``. - The first axis is interpreted as time. + Accepts: + - (B, T, *state_shape) tensor, where B is the number of time series, + T is the length of each time series. + + :param torch.Tensor data: Time series tensor with shape: + ``[B, T, *state_shape]`` B=data.shape[0] is the number of time series. :param int unroll_length: Number of transitions in each window. Each window has length ``unroll_length + 1``. :param int num_unrolls: Maximum number of windows to return. If ``None``, all valid windows are returned. Default is ``None``. :param bool randomize: If ``True``, starting indices are randomly permuted before applying ``num_unrolls``. Default is ``True``. - :return: Tensor of unroll windows with shape ``[Nw, unroll_length + 1, *state_shape]``. + :return: Tensor with shape ``[Nw, Twin, *state_shape]`` where Twin = unroll_length + 1 and + Nw = B*num_unrolls for B time seires. If no valid windows exist, returns an empty tensor with shape ``[0, unroll_length + 1, *state_shape]``. :rtype: torch.Tensor """ + if data.dim() < 3: + raise ValueError( + f"data must have at least 3 dimensions (B, T, *state_shape), got {data.shape}" + ) + + twin = unroll_length + 1 + n_steps = data.shape[ + 1 + ] # TODO: implement check that ensures that time dim is 1 starts = AutoregressiveSolver.decide_starting_indices( - data, unroll_length, num_unrolls, randomize + n_steps, unroll_length, num_unrolls, randomize, device=data.device ) if starts.numel() == 0: return torch.empty( - (0, unroll_length + 1, *data.shape[1:]), device=data.device + (0, unroll_length + 1, *data.shape[2:]), device=data.device ) windows = [ - data[int(start) : int(start) + unroll_length + 1] - for start in starts + data[:, int(s) : int(s) + twin, ...] for s in starts.tolist() ] - - return torch.stack( - windows, dim=0 - ) # [num_unrolls, unroll_length + 1, *data.shape[1:]] + windows = torch.stack(windows, dim=1) + return windows.reshape(-1, twin, *data.shape[2:]) @staticmethod def decide_starting_indices( - data, unroll_length, num_unrolls=None, randomize=True + n_steps: int, + unroll_length: int, + num_unrolls=None, + randomize=True, + device=None, ): """ Determine starting indices for unroll windows. @@ -141,25 +157,27 @@ def decide_starting_indices( Computes valid starting positions ensuring each window has enough subsequent time steps for the specified unroll length. - :param torch.Tensor data: Time series tensor with shape ``[T, *state_shape]``. + :param int n_steps: Total number of time steps in the data. :param int unroll_length: Number of transitions in each window. :param int num_unrolls: Maximum number of indices to return. If ``None``, all valid indices are returned. Default is ``None``. :param bool randomize: If ``True``, indices are randomly permuted before applying ``num_unrolls``. Default is ``True``. - :return: 1D tensor of starting indices with dtype ``torch.long``. + :param torch.device device: Device for the output tensor. If ``None``, uses the default device. + Default is ``None``. + :return: 1D tensor of starting indices for unroll windows. :rtype: torch.Tensor """ - n_step = int(data.shape[0]) - twin = int(unroll_length + 1) - last_start = n_step - twin + twin = unroll_length + 1 + last_start = n_steps - twin + if last_start < 0: - return torch.empty(0, dtype=torch.long, device=data.device) + return torch.empty(0, dtype=torch.long, device=device) - indices = torch.arange(last_start + 1, device=data.device) + indices = torch.arange(last_start + 1, device=device) if randomize: - indices = indices[torch.randperm(len(indices), device=data.device)] + indices = indices[torch.randperm(len(indices), device=device)] if num_unrolls is not None and num_unrolls < len(indices): indices = indices[:num_unrolls] @@ -298,25 +316,15 @@ def predict(self, initial_state, num_steps): a trajectory of length ``num_steps + 1`` (including the initial state). :param torch.Tensor initial_state: Starting state. Supported shapes: - - ``[n_features]`` (unbatched, 1D) - - ``[B, n_features]`` (batched) - More general tensors ``[*state_shape]`` / ``[B, *state_shape]`` are also - supported, provided the model can process them. + - ``[B, *state_shape]`` B: number of time series :param int num_steps: Number of future time steps to predict. :return: Predicted trajectory including the initial state. Shape: - - ``[num_steps + 1, *state_shape]`` if unbatched input - - ``[num_steps + 1, B, *state_shape]`` if batched input + - ``[B, num_steps + 1, *state_shape]`` if batched input :rtype: torch.Tensor """ self.eval() # Set model to evaluation mode current_state = initial_state - - added_batch = False - if current_state.dim() == 1: - current_state = current_state.unsqueeze(0) - added_batch = True - predictions = [current_state] with torch.no_grad(): for step in range(num_steps): @@ -324,8 +332,8 @@ def predict(self, initial_state, num_steps): predictions.append(next_state) current_state = next_state - out = torch.stack(predictions, dim=0) - if added_batch: - out = out[:, 0, ...] # remove batch dimension + out = torch.stack( + predictions, dim=1 + ) # [B, num_steps + 1, *state_shape] return out diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py index 788f6c081..a284ac4e2 100644 --- a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py +++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py @@ -83,13 +83,11 @@ def predict(self, initial_state, num_steps): """ Generate predictions by recursively applying the model. - :param torch.Tensor initial_state: Starting state. Supported shapes are: - - ``[*state_shape]`` (unbatched) - - ``[B, *state_shape]`` (batched) + :param torch.Tensor initial_state: Starting state. + Supported shapes is ``[B, *state_shape]`` (batched) :param int num_steps: Number of future time steps to predict. - :return: Predicted trajectory including the initial state. Shape: - - ``[num_steps + 1, *state_shape]`` if unbatched input - - ``[num_steps + 1, B, *state_shape]`` if batched input + :return: Predicted trajectory including the initial state. + Shape:``[B, num_steps + 1, *state_shape]`` if batched input :rtype: torch.Tensor """ pass diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py index 7126dec0e..46d8ccdc4 100644 --- a/tests/test_solver/test_autoregressive_solver.py +++ b/tests/test_solver/test_autoregressive_solver.py @@ -9,14 +9,17 @@ # Set random seed for reproducibility torch.manual_seed(42) +NUM_TIMESERIES = 2 def _make_series(T, F): torch.manual_seed(42) - y = torch.zeros(T, F) + num_t_series = NUM_TIMESERIES + y = torch.zeros(num_t_series, T, F) y[0] = torch.rand(F) + y[1] = torch.rand(F) for t in range(T - 1): - y[t + 1] = 0.95 * y[t] + y[:, t + 1] = 0.95 * y[:, t] return y @@ -62,7 +65,7 @@ class Problem(AbstractProblem): solver = AutoregressiveSolver( problem=problem, model=MinimalModel(), - optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.018), + optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.008), ) # PHASE1: train with 'short' condition only y_short = AutoregressiveSolver.unroll( @@ -82,9 +85,9 @@ class Problem(AbstractProblem): problem.conditions.clear() problem.conditions["medium"] = DataCondition(input=y_medium) problem.conditions_settings.clear() - problem.conditions_settings["medium"] = {"eps": 0.15} + problem.conditions_settings["medium"] = {"eps": 0.2} trainer2 = Trainer( - solver, max_epochs=500, accelerator="cpu", enable_model_summary=False + solver, max_epochs=1500, accelerator="cpu", enable_model_summary=False ) trainer2.train() @@ -95,20 +98,24 @@ class Problem(AbstractProblem): problem.conditions.clear() problem.conditions["long"] = DataCondition(input=y_long) problem.conditions_settings.clear() - problem.conditions_settings["long"] = {"eps": 0.3} + problem.conditions_settings["long"] = {"eps": 0.2} trainer3 = Trainer( - solver, max_epochs=400, accelerator="cpu", enable_model_summary=False + solver, max_epochs=4000, accelerator="cpu", enable_model_summary=False ) trainer3.train() test_start_idx = 50 num_predictions = 49 - start_state = y_data_large[test_start_idx] + start_state = y_data_large[:, test_start_idx, :] ground_truth = y_data_large[ - test_start_idx : test_start_idx + num_predictions + 1 + :, test_start_idx : test_start_idx + num_predictions + 1, : ] prediction = solver.predict(start_state, num_steps=num_predictions) - total_mse = torch.nn.functional.mse_loss(prediction[1:], ground_truth[1:]) + + assert prediction.shape == ground_truth.shape + total_mse = torch.nn.functional.mse_loss( + prediction[:, 1:, :], ground_truth[:, 1:, :] + ) assert total_mse < 1e-6 @@ -140,25 +147,33 @@ def forward(self, x): def test_unroll_shape_and_content(y_data): - # unroll_length=4 -> Twin=5 + B, T, F = y_data.shape + w = AutoregressiveSolver.unroll( y_data, unroll_length=4, num_unrolls=2, randomize=False ) - assert w.shape == (2, 5, NUM_FEATURES) - # deterministic starts: 0 and 1 - assert torch.allclose(w[0], y_data[0:5]) - assert torch.allclose(w[1], y_data[1:6]) + assert w.shape == (2 * B, 5, F) + + # windows for first time series + assert torch.allclose(w[0], y_data[0, 0:5, :]) + assert torch.allclose(w[1], y_data[0, 1:6, :]) + + # windows for second time series + assert torch.allclose(w[2], y_data[1, 0:5, :]) + assert torch.allclose(w[3], y_data[1, 1:6, :]) def test_decide_starting_indices_edge_cases(y_data): + n_steps = y_data.shape[1] + # print("n_steps is ",n_steps) idx = AutoregressiveSolver.decide_starting_indices( - y_data, unroll_length=3, num_unrolls=None, randomize=False + n_steps, unroll_length=3, num_unrolls=None, randomize=False ) # T=10, Twin=4 => last_start=6 => 0..6 assert torch.equal(idx, torch.arange(7)) idx_empty = AutoregressiveSolver.decide_starting_indices( - y_data, + n_steps, unroll_length=NUM_TIMESTEPS + 5, num_unrolls=None, randomize=False, @@ -201,7 +216,7 @@ class Problem(AbstractProblem): solver = AutoregressiveSolver(problem=Problem(), model=ExactModel()) - pred = solver.predict(y_data[0], num_steps=NUM_TIMESTEPS - 1) + pred = solver.predict(y_data[:, 0, :], num_steps=NUM_TIMESTEPS - 1) assert pred.shape == y_data.shape assert torch.allclose(pred, y_data, atol=1e-6) @@ -216,9 +231,9 @@ class Problem(AbstractProblem): step_losses = torch.tensor([1.0, 2.0, 3.0]) - w = solver._compute_adaptive_weights(step_losses, eps=1.0) - assert torch.isfinite(w).all() - assert torch.isclose(w.sum(), torch.tensor(1.0), atol=1e-6) + w1 = solver._compute_adaptive_weights(step_losses, eps=1.0) + assert torch.isfinite(w1).all() + assert torch.isclose(w1.sum(), torch.tensor(1.0), atol=1e-6) w2 = solver._compute_adaptive_weights(step_losses, eps=None) assert torch.isclose(w2.sum(), torch.tensor(1.0), atol=1e-6)