Skip to content
Open
185 changes: 185 additions & 0 deletions autoregressive_codice_prova.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import torch
import matplotlib.pyplot as plt

from pina import Trainer
from pina.optim import TorchOptimizer
from pina.problem import AbstractProblem
from pina.condition.data_condition import DataCondition
from pina.solver import AutoregressiveSolver

NUM_TIMESTEPS = 100
NUM_FEATURES = 15
USE_TEST_MODEL = False

# ============================================================================
# DATA
# ============================================================================

torch.manual_seed(42)

y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
y[0] = torch.rand(NUM_FEATURES) # Random initial state

for t in range(NUM_TIMESTEPS - 1):
y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum())

# ============================================================================
# TRAINING
# ============================================================================

class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(y.shape[1], 15),
torch.nn.Tanh(),
# torch.nn.Dropout(0.1),
torch.nn.Linear(15, y.shape[1]),
)

def forward(self, x):
return x + self.layers(x)


class TestModel(torch.nn.Module):
"""
Debug model that implements the EXACT transformation rule.
y[t+1] = 0.95 * y[t]
Expected loss is zero
"""

def __init__(self, data_series=None):
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.zeros(1))

def forward(self, x):
next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
return next_state + 0.0 * self.dummy_param

# create a problem with duplicated data conditions
class Problem(AbstractProblem):
output_variables = None
input_variables = None

# create two different unroll datasets: short and medium
y_short = AutoregressiveSolver.unroll(
y, unroll_length=4, num_unrolls=20, randomize=False
)
y_medium = AutoregressiveSolver.unroll(
y, unroll_length=10, num_unrolls=15, randomize=False
)
y_long = AutoregressiveSolver.unroll(
y, unroll_length=20, num_unrolls=10, randomize=False
)

conditions = {}

inactive_conditions = {
"short": DataCondition(input=y_short),
"medium": DataCondition(input=y_medium),
"long": DataCondition(input=y_long),
}

# Settings kept separate from the DataCondition objects
conditions_settings = {
"short": {"eps": 0.1},
"medium": {"eps": 1.0},
"long": {"eps": 2.0},
}


problem = Problem()

# helper that allows to activate or replace a condition at runtime
def activate_condition(problem, name, data=None, settings=None):
"""
Activate a single condition by name.

`conditions_settings` is left untouched unless `settings` is explicitly
provided and no entry exists yet for `name`.
"""
# if data is provided, (re)register condition in inactive store
if data is not None:
problem.inactive_conditions[name] = DataCondition(input=data)

problem.conditions = {}
problem.conditions[name] = problem.inactive_conditions[name]

if settings is not None:
problem.conditions_settings[name] = settings

# configure solver and trainer
solver = AutoregressiveSolver(
problem=problem,
model=TestModel() if USE_TEST_MODEL else SimpleModel(),
optimizer=TorchOptimizer(torch.optim.AdamW, lr=0.011),
)


print("Beginning phase 1: training with 'short' condition only")
activate_condition(problem, "short")
trainer1 = Trainer(solver, max_epochs=300, accelerator="cpu", enable_model_summary=False)
trainer1.train()

print("Beginning phase 2: training with 'medium' condition added")
activate_condition(problem, "medium")
trainer2 = Trainer(solver, max_epochs=500, accelerator="cpu", enable_model_summary=False)
trainer2.train()

print("Beginning phase 3: training with 'long' condition added")
activate_condition(problem, "long")
trainer3 = Trainer(solver, max_epochs=900, accelerator="cpu", enable_model_summary=False)
trainer3.train()


# ============================================================================
test_start_idx = 50
num_prediction_steps = 49
initial_state = y[test_start_idx] # Shape: [features]
predictions = solver.predict(initial_state, num_prediction_steps)
actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1]

print("\n=== PREDICTION DEBUG ===")
for i in range(min(10, num_prediction_steps)):
pred_val = predictions[i].mean().item()
actual_val = actual[i].mean().item()
error = (predictions[i] - actual[i]).abs().mean().item()
print(f"Step {i}: pred={pred_val:.4f}, actual={actual_val:.4f}, error={error:.4f}")

total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:])
print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}")

# visualize single dof
dof_to_plot = [0, 3, 6, 9, 12]
colors = [
"r",
"g",
"b",
"c",
"m",
"y",
"k",
]
plt.figure(figsize=(10, 6))
for dof, color in zip(dof_to_plot, colors):
plt.plot(
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
actual[:, dof].numpy(),
label="Actual",
marker="o",
color=color,
markerfacecolor="none",
)
plt.plot(
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
predictions[:, dof].numpy(),
label="Predicted",
marker="x",
color=color,
)

plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}")
plt.legend()
plt.xlabel("Timestep")
plt.savefig(f"autoregressive_predictions.png")
plt.close()
5 changes: 5 additions & 0 deletions pina/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"DeepEnsembleSupervisedSolver",
"DeepEnsemblePINN",
"GAROM",
"AutoregressiveSolver",
]

from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
Expand All @@ -41,3 +42,7 @@
DeepEnsemblePINN,
)
from .garom import GAROM
from .autoregressive_solver import (
AutoregressiveSolver,
AutoregressiveSolverInterface,
)
4 changes: 4 additions & 0 deletions pina/solver/autoregressive_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]

from .autoregressive_solver import AutoregressiveSolver
from .autoregressive_solver_interface import AutoregressiveSolverInterface
Loading