diff --git a/autoregressive_codice_prova.py b/autoregressive_codice_prova.py
new file mode 100644
index 000000000..69b194d35
--- /dev/null
+++ b/autoregressive_codice_prova.py
@@ -0,0 +1,185 @@
+import torch
+import matplotlib.pyplot as plt
+
+from pina import Trainer
+from pina.optim import TorchOptimizer
+from pina.problem import AbstractProblem
+from pina.condition.data_condition import DataCondition
+from pina.solver import AutoregressiveSolver
+
+NUM_TIMESTEPS = 100
+NUM_FEATURES = 15
+USE_TEST_MODEL = False
+
+# ============================================================================
+# DATA
+# ============================================================================
+
+torch.manual_seed(42)
+
+y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
+y[0] = torch.rand(NUM_FEATURES)  # Random initial state
+
+for t in range(NUM_TIMESTEPS - 1):
+    y[t + 1] = 0.95 * y[t]  # + 0.05 * torch.sin(y[t].sum())
+
+# ============================================================================
+# TRAINING
+# ============================================================================
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.Sequential(
+            torch.nn.Linear(y.shape[1], 15),
+            torch.nn.Tanh(),
+            # torch.nn.Dropout(0.1),
+            torch.nn.Linear(15, y.shape[1]),
+        )
+
+    def forward(self, x):
+        return x + self.layers(x)
+
+
+class TestModel(torch.nn.Module):
+    """
+    Debug model that implements the EXACT transformation rule
+    y[t+1] = 0.95 * y[t].
+    Expected loss is zero.
+    """
+
+    def __init__(self, data_series=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        next_state = 0.95 * x  # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
+        return next_state + 0.0 * self.dummy_param
+
+# create a problem with multiple data conditions
+class Problem(AbstractProblem):
+    output_variables = None
+    input_variables = None
+
+    # create three unroll datasets: short, medium, and long
+    y_short = AutoregressiveSolver.unroll(
+        y, unroll_length=4, num_unrolls=20, randomize=False
+    )
+    y_medium = AutoregressiveSolver.unroll(
+        y, unroll_length=10, num_unrolls=15, randomize=False
+    )
+    y_long = AutoregressiveSolver.unroll(
+        y, unroll_length=20, num_unrolls=10, randomize=False
+    )
+
+    conditions = {}
+
+    inactive_conditions = {
+        "short": DataCondition(input=y_short),
+        "medium": DataCondition(input=y_medium),
+        "long": DataCondition(input=y_long),
+    }
+
+    # Settings kept separate from the DataCondition objects
+    conditions_settings = {
+        "short": {"eps": 0.1},
+        "medium": {"eps": 1.0},
+        "long": {"eps": 2.0},
+    }
+
+
+problem = Problem()
+
+# helper to activate or replace a condition at runtime
+def activate_condition(problem, name, data=None, settings=None):
+    """
+    Activate a single condition by name, deactivating all others.
+
+    `conditions_settings` is left untouched unless `settings` is
+    explicitly provided.
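+
+    Example (hypothetical usage, mirroring the phased training below;
+    "extra" is an illustrative condition name):
+
+        activate_condition(problem, "medium")
+        activate_condition(problem, "extra", data=problem.y_short,
+                           settings={"eps": 0.5})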
+    """
+    # if data is provided, (re)register the condition in the inactive store
+    if data is not None:
+        problem.inactive_conditions[name] = DataCondition(input=data)
+
+    problem.conditions = {}
+    problem.conditions[name] = problem.inactive_conditions[name]
+
+    if settings is not None:
+        problem.conditions_settings[name] = settings
+
+# configure solver and trainer
+solver = AutoregressiveSolver(
+    problem=problem,
+    model=TestModel() if USE_TEST_MODEL else SimpleModel(),
+    optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.011),
+)
+
+
+print("Beginning phase 1: training with 'short' condition only")
+activate_condition(problem, "short")
+trainer1 = Trainer(solver, max_epochs=300, accelerator="cpu", enable_model_summary=False)
+trainer1.train()
+
+print("Beginning phase 2: training with 'medium' condition only")
+activate_condition(problem, "medium")
+trainer2 = Trainer(solver, max_epochs=500, accelerator="cpu", enable_model_summary=False)
+trainer2.train()
+
+print("Beginning phase 3: training with 'long' condition only")
+activate_condition(problem, "long")
+trainer3 = Trainer(solver, max_epochs=900, accelerator="cpu", enable_model_summary=False)
+trainer3.train()
+
+
+# ============================================================================
+# PREDICTION
+# ============================================================================
+test_start_idx = 50
+num_prediction_steps = 49
+initial_state = y[test_start_idx].unsqueeze(0)  # Shape: [1, features]; predict expects a batch
+predictions = solver.predict(initial_state, num_prediction_steps)[0]  # Shape: [steps + 1, features]
+actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1]
+
+print("\n=== PREDICTION DEBUG ===")
+for i in range(min(10, num_prediction_steps)):
+    pred_val = predictions[i].mean().item()
+    actual_val = actual[i].mean().item()
+    error = (predictions[i] - actual[i]).abs().mean().item()
+    print(f"Step {i}: pred={pred_val:.4f}, actual={actual_val:.4f}, error={error:.4f}")
+
+total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:])
+print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}")
+
+# visualize a few individual DOFs
+dof_to_plot = [0, 3, 6, 9, 12]
+colors = [
+    "r",
+    "g",
+    "b",
+    "c",
+    "m",
+    "y",
+    "k",
+]
+plt.figure(figsize=(10, 6))
+for dof, color in zip(dof_to_plot, colors):
+    plt.plot(
+        range(test_start_idx, test_start_idx + num_prediction_steps + 1),
+        actual[:, dof].numpy(),
+        label="Actual" if dof == dof_to_plot[0] else None,
+        marker="o",
+        color=color,
+        markerfacecolor="none",
+    )
+    plt.plot(
+        range(test_start_idx, test_start_idx + num_prediction_steps + 1),
+        predictions[:, dof].numpy(),
+        label="Predicted" if dof == dof_to_plot[0] else None,
+        marker="x",
+        color=color,
+    )
+
+plt.title(f"Autoregressive Predictions vs Actual, MSE: {total_mse:.6f}")
+plt.legend()
+plt.xlabel("Timestep")
+plt.savefig("autoregressive_predictions.png")
+plt.close()
diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py
index 43f18078f..e7d48e2b3 100644
--- a/pina/solver/__init__.py
+++ b/pina/solver/__init__.py
@@ -18,6 +18,7 @@
     "DeepEnsembleSupervisedSolver",
     "DeepEnsemblePINN",
     "GAROM",
+    "AutoregressiveSolver",
 ]
 
 from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
@@ -41,3 +42,7 @@
     DeepEnsemblePINN,
 )
 from .garom import GAROM
+from .autoregressive_solver import (
+    AutoregressiveSolver,
+    AutoregressiveSolverInterface,
+)
diff --git a/pina/solver/autoregressive_solver/__init__.py b/pina/solver/autoregressive_solver/__init__.py
new file mode 100644
index 000000000..9ef7c43e1
--- /dev/null
+++ b/pina/solver/autoregressive_solver/__init__.py
@@ -0,0 +1,4 @@
+__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]
+
+from .autoregressive_solver import AutoregressiveSolver
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver.py b/pina/solver/autoregressive_solver/autoregressive_solver.py
new file mode 100644
index 000000000..5ea4c9b48
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,339 @@
+import torch
+from pina.utils import check_consistency
+from pina.solver.solver import SingleSolverInterface
+from pina.condition import DataCondition
+from .autoregressive_solver_interface import AutoregressiveSolverInterface
+from typing import List
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class AutoregressiveSolver(
+    AutoregressiveSolverInterface, SingleSolverInterface
+):
+    r"""
+    Autoregressive Solver for learning dynamical systems.
+
+    This solver learns a one-step transition function
+    :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps
+    a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`.
+
+    During training, the model is unrolled over multiple time steps to
+    learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`,
+    the model generates predictions recursively:
+
+    .. math::
+        \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t),
+        \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0
+
+    The loss is computed over the entire unroll window:
+
+    .. math::
+        \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2
+
+    where :math:`w_t` are exponential weights (if ``eps`` is specified)
+    that down-weight later predictions to stabilize training.
+    """
+
+    accepted_conditions_types = DataCondition
+
+    def __init__(
+        self,
+        problem,
+        model,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=False,
+        reset_weighting_at_epoch_start=True,
+    ):
+        """
+        Initialization of the :class:`AutoregressiveSolver` class.
+
+        :param AbstractProblem problem: The problem instance containing
+            the time series data conditions.
+        :param torch.nn.Module model: Neural network that predicts the
+            next state given the current state.
+        :param torch.nn.Module loss: Loss function to minimize.
+            If ``None``, :class:`torch.nn.MSELoss` is used.
+            Default is ``None``.
+        :param TorchOptimizer optimizer: Optimizer for training.
+            If ``None``, :class:`torch.optim.Adam` is used.
+            Default is ``None``.
+        :param TorchScheduler scheduler: Learning rate scheduler.
+            If ``None``, no scheduling is applied. Default is ``None``.
+        :param WeightingInterface weighting: Weighting scheme for
+            combining losses from multiple conditions.
+            If ``None``, uniform weighting is used. Default is ``None``.
+        :param bool use_lt: Whether to use LabelTensors.
+            Default is ``False``.
+        :param bool reset_weighting_at_epoch_start: If ``True``, resets
+            the running averages used for adaptive weighting at the start
+            of each epoch. Default is ``True``. This parameter targets an
+            advanced use case: setting it to ``False`` can improve
+            stability, especially when data per epoch are very scarce.
+        """
+
+        super().__init__(
+            problem=problem,
+            model=model,
+            loss=loss,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            weighting=weighting,
+            use_lt=use_lt,
+        )
+        self._running_avg_step_losses = {}
+        self._running_step_counts = {}
+        self.reset_weighting_at_epoch_start = reset_weighting_at_epoch_start
+
+    @staticmethod
+    def unroll(
+        data, unroll_length: int, num_unrolls=None, randomize: bool = True
+    ):
+        """
+        Create unroll windows from time series data.
+
+        This is a pre-processing step. It slices data into overlapping
+        windows of length ``Twin = unroll_length + 1`` along the time axis,
+        which must be ``data.shape[1]``. Each window contains the initial
+        state and the subsequent target states used to compute a
+        multi-step loss.
+
+        Accepts a ``(B, T, *state_shape)`` tensor, where B is the number
+        of time series and T is the length of each time series.
+
+        :param torch.Tensor data: Time series tensor with shape
+            ``[B, T, *state_shape]``, where ``B = data.shape[0]`` is the
+            number of time series.
+        :param int unroll_length: Number of transitions in each window.
+            Each window has length ``unroll_length + 1``.
+        :param int num_unrolls: Maximum number of windows to return per
+            time series. If ``None``, all valid windows are returned.
+            Default is ``None``.
+        :param bool randomize: If ``True``, starting indices are randomly
+            permuted before applying ``num_unrolls``. Default is ``True``.
+        :return: Tensor with shape ``[Nw, Twin, *state_shape]``, where
+            ``Twin = unroll_length + 1`` and ``Nw = B * num_unrolls`` for
+            B time series. If no valid windows exist, returns an empty
+            tensor with shape ``[0, unroll_length + 1, *state_shape]``.
+        :rtype: torch.Tensor
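+
+        Example (a minimal sketch of the expected shapes):
+
+        .. code-block:: python
+
+            series = torch.rand(2, 100, 15)  # B=2 series, T=100 steps, 15 features
+            w = AutoregressiveSolver.unroll(series, unroll_length=4, num_unrolls=20)
+            # w.shape == torch.Size([40, 5, 15]), i.e. 2 * 20 windows of length 4 + 1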
+        """
+        if data.dim() < 3:
+            raise ValueError(
+                f"data must have at least 3 dimensions (B, T, *state_shape), got {data.shape}"
+            )
+
+        twin = unroll_length + 1
+        n_steps = data.shape[1]  # TODO: add a check that the time dimension is dim 1
+        starts = AutoregressiveSolver.decide_starting_indices(
+            n_steps, unroll_length, num_unrolls, randomize, device=data.device
+        )
+        if starts.numel() == 0:
+            return torch.empty(
+                (0, unroll_length + 1, *data.shape[2:]), device=data.device
+            )
+
+        windows = [
+            data[:, int(s) : int(s) + twin, ...] for s in starts.tolist()
+        ]
+        windows = torch.stack(windows, dim=1)
+        return windows.reshape(-1, twin, *data.shape[2:])
+
+    @staticmethod
+    def decide_starting_indices(
+        n_steps: int,
+        unroll_length: int,
+        num_unrolls=None,
+        randomize=True,
+        device=None,
+    ):
+        """
+        Determine starting indices for unroll windows.
+
+        Computes valid starting positions, ensuring each window has enough
+        subsequent time steps for the specified unroll length.
+
+        :param int n_steps: Total number of time steps in the data.
+        :param int unroll_length: Number of transitions in each window.
+        :param int num_unrolls: Maximum number of indices to return. If ``None``,
+            all valid indices are returned. Default is ``None``.
+        :param bool randomize: If ``True``, indices are randomly permuted before
+            applying ``num_unrolls``. Default is ``True``.
+        :param torch.device device: Device for the output tensor. If ``None``,
+            the default device is used. Default is ``None``.
+        :return: 1D tensor of starting indices for unroll windows.
+        :rtype: torch.Tensor
+        """
+        twin = unroll_length + 1
+        last_start = n_steps - twin
+
+        if last_start < 0:
+            return torch.empty(0, dtype=torch.long, device=device)
+
+        indices = torch.arange(last_start + 1, device=device)
+
+        if randomize:
+            indices = indices[torch.randperm(len(indices), device=device)]
+
+        if num_unrolls is not None and num_unrolls < len(indices):
+            indices = indices[:num_unrolls]
+
+        return indices
+
+    def loss_data(
+        self, unroll, eps=None, aggregation_strategy=None, condition_name=None
+    ):
+        """
+        Compute the autoregressive multi-step data loss.
+
+        The input ``unroll`` is expected to be a batch of precomputed unroll
+        windows with shape ``[B, Twin, *state_shape]``. The first element
+        along the ``Twin`` axis is used as the current state, and the
+        following elements are the targets.
+
+        :param torch.Tensor unroll: Batch of unroll windows with shape
+            ``[B, Twin, *state_shape]`` where ``Twin = unroll_length + 1``.
+        :param float eps: If provided, applies step weighting through
+            :meth:`get_weights`. If ``None``, uniform normalized weights
+            are used. Default is ``None``.
+        :param callable aggregation_strategy: Reduction applied to the
+            weighted per-step losses. If ``None``, :func:`torch.sum` is used.
+            Default is ``None``.
+        :param str condition_name: Name of the condition whose running
+            average of step losses is used for weighting. If ``None``,
+            ``"default"`` is used. Default is ``None``.
+        :return: Scalar loss value for the given batch.
+        :rtype: torch.Tensor
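+
+        Example (sketch; ``solver`` is any configured
+        ``AutoregressiveSolver`` and ``series`` a ``[B, T, *state_shape]``
+        tensor):
+
+        .. code-block:: python
+
+            windows = AutoregressiveSolver.unroll(series, unroll_length=4)
+            loss = solver.loss_data(windows, eps=0.5, condition_name="short")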
+        """
+        # the batch dimension is unroll.shape[0], i.e. the number of unrolls
+        Twin = unroll.shape[1]
+
+        current_state = unroll[:, 0, ...]  # first time step of each batch
+        losses = []
+        for step in range(1, Twin):
+
+            predicted_state = self.forward(
+                current_state
+            )  # [num_unrolls, features]
+            target_state = unroll[:, step, ...]  # [num_unrolls, features]
+            step_loss = self._loss_fn(predicted_state, target_state)
+            losses.append(step_loss)
+
+            if step <= 3 or torch.isnan(step_loss):
+                logger.debug(
+                    "  Step %d: loss=%.4e, pred=[%.3f, %.3f]",
+                    step,
+                    float(step_loss.item()),
+                    float(predicted_state.detach().min()),
+                    float(predicted_state.detach().max()),
+                )
+
+            current_state = predicted_state
+
+        step_losses = torch.stack(losses)  # [unroll_length]
+
+        with torch.no_grad():
+            condition_name = condition_name or "default"
+            weights = self.get_weights(condition_name, step_losses, eps)
+
+            logger.debug(
+                "  Losses: %s", step_losses.detach().cpu().numpy().round(4)
+            )
+            logger.debug("  Weights: %s", weights.cpu().numpy().round(4))
+            logger.debug(
+                "  Weight ratio: %.1f", float(weights.max() / weights.min())
+            )
+
+        if aggregation_strategy is None:
+            aggregation_strategy = torch.sum
+
+        return aggregation_strategy(step_losses * weights)
+
+    def get_weights(self, condition_name, step_losses, eps):
+        """
+        Return cached weights or compute new ones.
+
+        :param str condition_name: Name of the condition.
+        :param torch.Tensor step_losses: 1D tensor of per-step losses.
+        :param float eps: Weighting parameter.
+        :return: Weights tensor.
+        :rtype: torch.Tensor
+        """
+        key = condition_name or "default"
+        x = step_losses.detach()
+
+        if x.dim() != 1:
+            raise ValueError(
+                f"step_losses must be a 1D tensor, got shape {x.shape}"
+            )
+
+        if key not in self._running_avg_step_losses:
+            self._running_avg_step_losses[key] = x.clone()
+            self._running_step_counts[key] = 1
+        else:
+            self._running_step_counts[key] += 1
+            k = self._running_step_counts[key]
+            # update the running average incrementally
+            self._running_avg_step_losses[key] += (
+                x - self._running_avg_step_losses[key]
+            ) / k
+
+        return self._compute_adaptive_weights(
+            self._running_avg_step_losses[key], eps
+        )
+
+    def _compute_adaptive_weights(self, step_losses, eps):
+        """
+        Actual computation of adaptive weights.
+
+        :param torch.Tensor step_losses: 1D tensor of per-step losses.
+        :param float eps: Weighting parameter.
+        :return: Computed weights tensor.
+        :rtype: torch.Tensor
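+
+        Example (illustrative values only):
+
+        .. code-block:: python
+
+            losses = torch.tensor([1.0, 2.0, 3.0])
+            w = solver._compute_adaptive_weights(losses, eps=1.0)
+            # w sums to 1 and decreases monotonically, down-weighting
+            # later (typically harder) unroll steps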
+        """
+        logger.debug(f"updating weights, eps={eps}")
+
+        if eps is None:
+            return torch.ones_like(step_losses) / step_losses.numel()
+
+        # normalize to mean 1 (avoid too large exponents)
+        step_losses = step_losses / (step_losses.mean() + 1e-12)
+
+        log_w = torch.clamp(-eps * torch.cumsum(step_losses, dim=0), -20, 20)
+        return torch.softmax(log_w, dim=0)
+
+    def on_train_epoch_start(self):
+        """
+        Hook called by Lightning at the beginning of each epoch.
+        Forces periodic cleaning of the dictionaries used for the
+        weighting estimate.
+        """
+        if self.reset_weighting_at_epoch_start:
+            self._running_avg_step_losses.clear()
+            self._running_step_counts.clear()
+
+    def predict(self, initial_state, num_steps):
+        """
+        Generate predictions by recursively applying the model.
+
+        Starting from ``initial_state``, applies the model repeatedly to
+        generate a trajectory of length ``num_steps + 1`` (including the
+        initial state).
+
+        :param torch.Tensor initial_state: Starting state with shape
+            ``[B, *state_shape]``, where B is the number of time series.
+        :param int num_steps: Number of future time steps to predict.
+        :return: Predicted trajectory including the initial state, with
+            shape ``[B, num_steps + 1, *state_shape]``.
+        :rtype: torch.Tensor
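+
+        Example (sketch; shapes only, ``series`` is a batched time series):
+
+        .. code-block:: python
+
+            y0 = series[:, 0, :]                     # [B, features]
+            traj = solver.predict(y0, num_steps=10)  # [B, 11, features]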
+        """
+        self.eval()  # set the model to evaluation mode
+
+        current_state = initial_state
+        predictions = [current_state]
+        with torch.no_grad():
+            for step in range(num_steps):
+                next_state = self.forward(current_state)
+                predictions.append(next_state)
+                current_state = next_state
+
+        out = torch.stack(
+            predictions, dim=1
+        )  # [B, num_steps + 1, *state_shape]
+
+        return out
diff --git a/pina/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
new file mode 100644
index 000000000..a284ac4e2
--- /dev/null
+++ b/pina/solver/autoregressive_solver/autoregressive_solver_interface.py
@@ -0,0 +1,103 @@
+"""Module for the Autoregressive solver interface."""
+
+from abc import abstractmethod
+import torch
+from torch.nn.modules.loss import _Loss
+from dataclasses import dataclass
+
+from ..solver import SolverInterface
+from ...utils import check_consistency
+from ...loss.loss_interface import LossInterface
+from ...condition import DataCondition
+from typing import Optional
+
+
+class AutoregressiveSolverInterface(SolverInterface):
+    """
+    Base class for autoregressive solvers.
+
+    The training pipeline expects :class:`~pina.condition.data_condition.DataCondition`
+    conditions. In the recommended configuration, each DataCondition input is a
+    collection of unroll windows with shape ``[Nw, Twin, *state_shape]``, where
+    ``Twin = unroll_length + 1``. The Trainer batches along the first axis,
+    producing ``[B, Twin, *state_shape]`` tensors passed to :meth:`loss_data`.
+    """
+
+    def __init__(self, loss=None, **kwargs):
+        """
+        Initialization of the :class:`AutoregressiveSolverInterface` class.
+
+        :param torch.nn.Module loss: Loss function to minimize. If ``None``,
+            :class:`torch.nn.MSELoss` is used. Default is ``None``.
+        :param kwargs: Additional keyword arguments forwarded to
+            :class:`~pina.solver.solver.SolverInterface`.
+        """
+
+        super().__init__(**kwargs)
+
+        if loss is None:
+            loss = torch.nn.MSELoss()
+
+        check_consistency(loss, (LossInterface, _Loss), subclass=False)
+        self._loss_fn = loss
+
+    def optimization_cycle(self, batch):
+        """
+        Optimization cycle for this family of solvers. Iterates over the
+        conditions in the batch and applies the specialized ``loss_data``
+        method to each.
+
+        :param list[tuple[str, dict]] batch: List of tuples where each
+            tuple contains a condition name and a dictionary with the
+            ``"input"`` key mapping to the time series tensor.
+        :return: Dictionary mapping condition names to computed loss values.
+        :rtype: dict[str, torch.Tensor]
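+
+        Example of the expected per-condition settings on a hypothetical
+        problem (``windows`` is a precomputed unroll tensor):
+
+        .. code-block:: python
+
+            class MyProblem(AbstractProblem):
+                output_variables = None
+                input_variables = None
+                conditions = {"short": DataCondition(input=windows)}
+                conditions_settings = {"short": {"eps": 0.1}}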
+        """
+
+        condition_loss = {}
+        for condition_name, points in batch:
+            settings = {}
+            if hasattr(self.problem, "conditions_settings"):
+                settings = self.problem.conditions_settings.get(
+                    condition_name, {}
+                )
+
+            loss = self.loss_data(
+                points["input"],
+                eps=settings.get("eps"),
+                condition_name=condition_name,
+            )
+            condition_loss[condition_name] = loss
+        return condition_loss
+
+    @abstractmethod
+    def loss_data(self, input, **settings):
+        """
+        Compute the data loss for each condition.
+        This method must be implemented by subclasses to define the
+        specific loss computation strategy.
+        """
+        pass
+
+    @abstractmethod
+    def predict(self, initial_state, num_steps):
+        """
+        Generate predictions by recursively applying the model.
+
+        :param torch.Tensor initial_state: Starting state. The supported
+            shape is ``[B, *state_shape]`` (batched).
+        :param int num_steps: Number of future time steps to predict.
+        :return: Predicted trajectory including the initial state, with
+            shape ``[B, num_steps + 1, *state_shape]`` for batched input.
+        :rtype: torch.Tensor
+        """
+        pass
+
+    @property
+    def loss(self):
+        """
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
+        """
+        return self._loss_fn
diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py
new file mode 100644
index 000000000..46d8ccdc4
--- /dev/null
+++ b/tests/test_solver/test_autoregressive_solver.py
@@ -0,0 +1,300 @@
+import pytest
+import torch
+
+from pina import Trainer
+from pina.optim import TorchOptimizer
+from pina.problem import AbstractProblem
+from pina.condition.data_condition import DataCondition
+from pina.solver import AutoregressiveSolver
+
+# Set random seed for reproducibility
+torch.manual_seed(42)
+NUM_TIMESERIES = 2
+
+
+def _make_series(T, F):
+    torch.manual_seed(42)
+    num_t_series = NUM_TIMESERIES
+    y = torch.zeros(num_t_series, T, F)
+    y[0] = torch.rand(F)
+    y[1] = torch.rand(F)
+    for t in range(T - 1):
+        y[:, t + 1] = 0.95 * y[:, t]
+    return y
+
+
+### END-TO-END #############################################################################
+
+
+@pytest.fixture
+def y_data_large():
+    return _make_series(T=100, F=15)
+
+
+class MinimalModel(torch.nn.Module):
+    """
+    Minimal model that applies a linear transformation.
+    Used for end-to-end testing. Since the problem dynamics are linear,
+    this model should in principle learn the correct transformation.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.Linear(15, 15, bias=False)
+
+    def forward(self, x):
+        return x + self.layers(x)
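+
+# Note: since forward(x) = x + W @ x, the exact dynamics y[t+1] = 0.95 * y[t]
+# are representable with W = -0.05 * I, so the training loss can in principle
+# reach zero.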
+
+
+def test_end_to_end(y_data_large):
+    """
+    End-to-end test with MinimalModel.
+    This test performs a 3-phase training with increasing unroll lengths,
+    showing how to use the AutoregressiveSolver with curriculum learning.
+    """
+
+    # AbstractProblem with empty conditions and conditions_settings to be
+    # filled later
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {}
+        conditions_settings = {}
+
+    problem = Problem()
+
+    solver = AutoregressiveSolver(
+        problem=problem,
+        model=MinimalModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.008),
+    )
+    # PHASE 1: train with 'short' condition only
+    y_short = AutoregressiveSolver.unroll(
+        y_data_large, unroll_length=4, num_unrolls=20, randomize=False
+    )
+    problem.conditions["short"] = DataCondition(input=y_short)
+    problem.conditions_settings["short"] = {"eps": 0.1}
+    trainer1 = Trainer(
+        solver, max_epochs=300, accelerator="cpu", enable_model_summary=False
+    )
+    trainer1.train()
+
+    # PHASE 2: train with 'medium' condition only
+    y_medium = AutoregressiveSolver.unroll(
+        y_data_large, unroll_length=10, num_unrolls=15, randomize=False
+    )
+    problem.conditions.clear()
+    problem.conditions["medium"] = DataCondition(input=y_medium)
+    problem.conditions_settings.clear()
+    problem.conditions_settings["medium"] = {"eps": 0.2}
+    trainer2 = Trainer(
+        solver, max_epochs=1500, accelerator="cpu", enable_model_summary=False
+    )
+    trainer2.train()
+
+    # PHASE 3: train with 'long' condition only
+    y_long = AutoregressiveSolver.unroll(
+        y_data_large, unroll_length=20, num_unrolls=10, randomize=False
+    )
+    problem.conditions.clear()
+    problem.conditions["long"] = DataCondition(input=y_long)
+    problem.conditions_settings.clear()
+    problem.conditions_settings["long"] = {"eps": 0.2}
+    trainer3 = Trainer(
+        solver, max_epochs=4000, accelerator="cpu", enable_model_summary=False
+    )
+    trainer3.train()
+
+    test_start_idx = 50
+    num_predictions = 49
+    start_state = y_data_large[:, test_start_idx, :]
+    ground_truth = y_data_large[
+        :, test_start_idx : test_start_idx + num_predictions + 1, :
+    ]
+    prediction = solver.predict(start_state, num_steps=num_predictions)
+
+    assert prediction.shape == ground_truth.shape
+    total_mse = torch.nn.functional.mse_loss(
+        prediction[:, 1:, :], ground_truth[:, 1:, :]
+    )
+    assert total_mse < 1e-6
+
+
+### UNIT TESTS #############################################################################
+
+NUM_TIMESTEPS = 10
+NUM_FEATURES = 3
+
+
+@pytest.fixture
+def y_data():
+    return _make_series(T=10, F=3)
+
+
+class ExactModel(torch.nn.Module):
+    """
+    This model implements the EXACT transformation rule
+    y[t+1] = 0.95 * y[t].
+    Expected loss is zero.
+    """
+
+    def __init__(self, data_series=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        next_state = 0.95 * x
+        return next_state + 0.0 * self.dummy_param
+
+
+def test_unroll_shape_and_content(y_data):
+    B, T, F = y_data.shape
+
+    w = AutoregressiveSolver.unroll(
+        y_data, unroll_length=4, num_unrolls=2, randomize=False
+    )
+    assert w.shape == (2 * B, 5, F)
+
+    # windows for the first time series
+    assert torch.allclose(w[0], y_data[0, 0:5, :])
+    assert torch.allclose(w[1], y_data[0, 1:6, :])
+
+    # windows for the second time series
+    assert torch.allclose(w[2], y_data[1, 0:5, :])
+    assert torch.allclose(w[3], y_data[1, 1:6, :])
+
+
+def test_decide_starting_indices_edge_cases(y_data):
+    n_steps = y_data.shape[1]
+    idx = AutoregressiveSolver.decide_starting_indices(
+        n_steps, unroll_length=3, num_unrolls=None, randomize=False
+    )
+    # T=10, Twin=4 => last_start=6 => 0..6
+    assert torch.equal(idx, torch.arange(7))
+
+    idx_empty = AutoregressiveSolver.decide_starting_indices(
+        n_steps,
+        unroll_length=NUM_TIMESTEPS + 5,
+        num_unrolls=None,
+        randomize=False,
+    )
+    assert idx_empty.numel() == 0
+
+
+def test_exact_model(y_data):
+    windows = AutoregressiveSolver.unroll(
+        y_data, unroll_length=5, num_unrolls=4, randomize=False
+    )
+
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {
+            "data_condition": DataCondition(input=windows),
+        }
+
+    solver = AutoregressiveSolver(
+        problem=Problem(),
+        model=ExactModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.01),
+    )
+
+    loss = solver.loss_data(
+        windows,
+        eps=None,
+        aggregation_strategy=torch.sum,
+        condition_name="data_condition",
+    )
+    assert torch.isclose(loss, torch.tensor(0.0), atol=1e-6)
+
+
+def test_predict_matches_ground_truth(y_data):
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    solver = AutoregressiveSolver(problem=Problem(), model=ExactModel())
+
+    pred = solver.predict(y_data[:, 0, :], num_steps=NUM_TIMESTEPS - 1)
+    assert pred.shape == y_data.shape
+    assert torch.allclose(pred, y_data, atol=1e-6)
+
+
+def test_adaptive_weights_are_finite_and_normalized(y_data):
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=y_data)}
+
+    solver = AutoregressiveSolver(problem=Problem(), model=ExactModel())
+
+    step_losses = torch.tensor([1.0, 2.0, 3.0])
+
+    w1 = solver._compute_adaptive_weights(step_losses, eps=1.0)
+    assert torch.isfinite(w1).all()
+    assert torch.isclose(w1.sum(), torch.tensor(1.0), atol=1e-6)
+
+    w2 = solver._compute_adaptive_weights(step_losses, eps=None)
+    assert torch.isclose(w2.sum(), torch.tensor(1.0), atol=1e-6)
+
+    w3 = solver.get_weights("data", step_losses, eps=1.0)
+    assert torch.isfinite(w3).all()
+    assert torch.isclose(w3.sum(), torch.tensor(1.0), atol=1e-6)
+
+
+def test_trainer_integration_one_epoch(y_data):
+    windows = AutoregressiveSolver.unroll(
+        y_data, unroll_length=5, num_unrolls=None, randomize=False
+    )
+
+    class Problem(AbstractProblem):
+        output_variables = None
+        input_variables = None
+        conditions = {"data": DataCondition(input=windows)}
+
+    solver = AutoregressiveSolver(
+        problem=Problem(),
+        model=ExactModel(),
+        optimizer=TorchOptimizer(torch.optim.AdamW, lr=1e-2),
+    )
+
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=1,
+        accelerator="cpu",
+    )
+    trainer.train()
+
+    with torch.no_grad():
+        loss = solver.loss_data(
+            windows[:4],
+            eps=None,
+            aggregation_strategy=torch.sum,
+            condition_name="data",
+        )
+    assert torch.isfinite(loss)
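+
+
+def test_unroll_randomize_respects_num_unrolls(y_data):
+    # Additional sketch of a check: with randomize=True the starting indices
+    # are permuted, so only the output shape is asserted here.
+    w = AutoregressiveSolver.unroll(
+        y_data, unroll_length=3, num_unrolls=4, randomize=True
+    )
+    assert w.shape == (4 * NUM_TIMESERIES, 4, NUM_FEATURES)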
accelerator="cpu", + ) + trainer.train() + + with torch.no_grad(): + loss = solver.loss_data( + windows[:4], + eps=None, + aggregation_strategy=torch.sum, + condition_name="data", + ) + assert torch.isfinite(loss) + + +def test_weight_cache_resets_on_epoch_start(y_data): + class Problem(AbstractProblem): + output_variables = None + input_variables = None + conditions = {"data": DataCondition(input=y_data)} + + solver = AutoregressiveSolver( + problem=Problem(), + model=ExactModel(), + reset_weighting_at_epoch_start=True, + ) + + step_losses = torch.tensor([1.0, 2.0, 3.0]) + + _ = solver.get_weights("data", step_losses, eps=1.0) + assert "data" in solver._running_avg_step_losses + assert "data" in solver._running_step_counts + + solver.on_train_epoch_start() + + assert solver._running_avg_step_losses == {} + assert solver._running_step_counts == {}