diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 49c84dab28..56142819a1 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -83,6 +83,12 @@ Panoptic Quality metrics handler :members: +Calibration Error metrics handler +--------------------------------- +.. autoclass:: CalibrationError + :members: + + Mean squared error metrics handler ---------------------------------- .. autoclass:: MeanSquaredError diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 88fbea7ff0..ae3d4a0c92 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -180,6 +180,15 @@ Metrics .. autoclass:: MetricsReloadedCategorical :members: +`Calibration Error` +------------------- +.. autofunction:: calibration_binning + +.. autoclass:: CalibrationReduction + :members: + +.. autoclass:: CalibrationErrorMetric + :members: Utilities diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index ed5db8a7f3..7fc7b6df57 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .average_precision import AveragePrecision +from .calibration import CalibrationError from .checkpoint_loader import CheckpointLoader from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver diff --git a/monai/handlers/calibration.py b/monai/handlers/calibration.py new file mode 100644 index 0000000000..c5fac0919f --- /dev/null +++ b/monai/handlers/calibration.py @@ -0,0 +1,109 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.handlers.ignite_metric import IgniteMetricHandler +from monai.metrics import CalibrationErrorMetric, CalibrationReduction +from monai.utils import MetricReduction + +__all__ = ["CalibrationError"] + + +class CalibrationError(IgniteMetricHandler): + """ + Ignite handler to compute Calibration Error during training or evaluation. + + **Why Calibration Matters:** + + A well-calibrated model produces probability estimates that match the true likelihood of correctness. + For example, predictions with 80% confidence should be correct approximately 80% of the time. + Modern neural networks often exhibit poor calibration (typically overconfident), which can be + problematic in medical imaging where probability estimates may inform clinical decisions. + + This handler wraps :py:class:`~monai.metrics.CalibrationErrorMetric` for use with PyTorch Ignite + engines, automatically computing and aggregating calibration errors across iterations. + + **Supported Calibration Metrics:** + + - **Expected Calibration Error (ECE)**: Weighted average of per-bin errors (most common). + - **Average Calibration Error (ACE)**: Unweighted average across bins. + - **Maximum Calibration Error (MCE)**: Worst-case calibration error. + + Args: + num_bins: Number of equally-spaced bins for calibration computation. Defaults to 20. 
+ include_background: Whether to include the first channel (index 0) in computation. + Set to ``False`` to exclude background in segmentation tasks. Defaults to ``True``. + calibration_reduction: Calibration error reduction mode. Options: ``"expected"`` (ECE), + ``"average"`` (ACE), ``"maximum"`` (MCE). Defaults to ``"expected"``. + metric_reduction: Reduction across batch/channel after computing per-sample errors. + Options: ``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``. Defaults to ``"mean"``. + output_transform: Callable to extract ``(y_pred, y)`` from ``engine.state.output``. + See `Ignite concepts `_ and + the batch output transform tutorial in the MONAI tutorials repository. + save_details: If ``True``, saves per-sample/per-channel metric values to + ``engine.state.metric_details[name]``. Defaults to ``True``. + + References: + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. + https://proceedings.mlr.press/v70/guo17a.html + - Barfoot, T., et al. "Average Calibration Error: A Differentiable Loss for Improved + Reliability in Image Segmentation." MICCAI 2024. + https://papers.miccai.org/miccai-2024/091-Paper3075.html + + See Also: + - :py:class:`~monai.metrics.CalibrationErrorMetric`: The underlying metric class. + - :py:func:`~monai.metrics.calibration_binning`: Low-level binning for reliability diagrams. + + Example: + >>> from monai.handlers import CalibrationError, from_engine + >>> from ignite.engine import Engine + >>> + >>> def evaluation_step(engine, batch): + ... # Returns dict with "pred" (probabilities) and "label" (one-hot) + ... return {"pred": model(batch["image"]), "label": batch["label"]} + >>> + >>> evaluator = Engine(evaluation_step) + >>> + >>> # Attach calibration error handler + >>> CalibrationError( + ... num_bins=15, + ... include_background=False, + ... calibration_reduction="expected", + ... output_transform=from_engine(["pred", "label"]), + ... 
).attach(evaluator, name="ECE") + >>> + >>> # After evaluation, access results + >>> evaluator.run(val_loader) + >>> ece = evaluator.state.metrics["ECE"] + >>> print(f"Expected Calibration Error: {ece:.4f}") + """ + + def __init__( + self, + num_bins: int = 20, + include_background: bool = True, + calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, + metric_reduction: MetricReduction | str = MetricReduction.MEAN, + output_transform: Callable = lambda x: x, + save_details: bool = True, + ) -> None: + metric_fn = CalibrationErrorMetric( + num_bins=num_bins, + include_background=include_background, + calibration_reduction=calibration_reduction, + metric_reduction=metric_reduction, + ) + + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index ae20903cfd..0da25feca9 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -13,6 +13,7 @@ from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score from .average_precision import AveragePrecisionMetric, compute_average_precision +from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .cumulative_average import CumulativeAverage from .f_beta_score import FBetaScore diff --git a/monai/metrics/calibration.py b/monai/metrics/calibration.py new file mode 100644 index 0000000000..b3df6aff28 --- /dev/null +++ b/monai/metrics/calibration.py @@ -0,0 +1,331 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch + +from monai.metrics.metric import CumulativeIterationMetric +from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.utils import MetricReduction +from monai.utils.enums import StrEnum + +__all__ = ["calibration_binning", "CalibrationErrorMetric", "CalibrationReduction"] + + +def calibration_binning( + y_pred: torch.Tensor, y: torch.Tensor, num_bins: int = 20, right: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute calibration bins for predicted probabilities and ground truth labels. + + This function implements hard binning for calibration analysis, grouping predictions + into bins based on their confidence values and computing statistics for each bin. + These statistics can be used to assess model calibration or plot reliability diagrams. + + A well-calibrated model should have predicted probabilities that match empirical accuracy. + For example, among all predictions with 80% confidence, approximately 80% should be correct. + This function provides the per-bin statistics needed to evaluate this property. 
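+
+    For instance, with ``num_bins=2``, predictions ``[0.1, 0.3, 0.7, 0.9]`` and targets
+    ``[0, 0, 1, 1]`` split into a low-confidence bin (mean prediction 0.2, empirical
+    accuracy 0.0) and a high-confidence bin (mean prediction 0.8, empirical accuracy 1.0),
+    so each bin shows an absolute calibration gap of 0.2.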
+ + The function operates on input and target tensors with batch and channel dimensions, + handling each batch and channel separately. For bins that do not contain any elements, + the mean predicted values and mean ground truth values are set to NaN. + + Args: + y_pred: Predicted probabilities with shape ``(B, C, spatial...)``, where B is batch size, + C is number of classes/channels, and spatial can be any number of dimensions (H, W, D, etc.). + Values should be in the range [0, 1]. + y: Ground truth tensor with the same shape as ``y_pred``. Should be one-hot encoded + or contain binary values (0 or 1) indicating the true class membership. + num_bins: Number of equally-spaced bins to divide the [0, 1] probability range into. + Defaults to 20. Must be >= 1. + right: Determines bin boundary inclusion. If False (default), bins include the left + boundary and exclude the right (i.e., [left, right)). If True, bins exclude the + left boundary and include the right (i.e., (left, right]). + + Returns: + A tuple of three tensors, each with shape ``(B, C, num_bins)``: + - **mean_p_per_bin**: Mean predicted probability for samples in each bin. + - **mean_gt_per_bin**: Mean ground truth value (empirical accuracy) for samples in each bin. + - **bin_counts**: Number of samples falling into each bin. + + Bins with no samples have NaN values for mean_p_per_bin and mean_gt_per_bin. + + Raises: + ValueError: If ``y_pred`` and ``y`` have different shapes, if input has fewer than + 3 dimensions, or if ``num_bins < 1``. + + References: + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. + https://proceedings.mlr.press/v70/guo17a.html + - Barfoot, T., et al. "Average Calibration Error: A Differentiable Loss for Improved + Reliability in Image Segmentation." MICCAI 2024. + https://papers.miccai.org/miccai-2024/091-Paper3075.html + + Note: + This function uses nested loops over batch and channel dimensions for binning operations. + For reliability diagram visualization, use the returned statistics to plot mean predicted + probability vs. empirical accuracy for each bin. 
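+
+        A minimal plotting sketch (assuming ``matplotlib`` is available; ``y_pred`` and
+        ``y`` are illustrative tensors shaped ``(B, C, spatial...)``) could look like::
+
+            import matplotlib.pyplot as plt
+
+            mean_p, mean_gt, counts = calibration_binning(y_pred, y, num_bins=10)
+            b, c = 0, 0  # pick one batch element and one channel
+            plt.plot(mean_p[b, c].cpu(), mean_gt[b, c].cpu(), marker="o")  # reliability curve
+            plt.plot([0, 1], [0, 1], linestyle="--")  # perfect-calibration reference
+            plt.xlabel("mean predicted probability")
+            plt.ylabel("empirical accuracy")
+            plt.show()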
+ + Example: + >>> import torch + >>> # Binary segmentation: batch=1, channels=2, spatial=4x4 + >>> y_pred = torch.rand(1, 2, 4, 4) # predicted probabilities + >>> y = torch.randint(0, 2, (1, 2, 4, 4)).float() # one-hot ground truth + >>> mean_p, mean_gt, counts = calibration_binning(y_pred, y, num_bins=10) + >>> # mean_p, mean_gt, counts each have shape (1, 2, 10) + """ + # Input validation + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y must have the same shape, got {y_pred.shape} and {y.shape}.") + if y_pred.ndim < 3: + raise ValueError(f"y_pred must have shape (B, C, spatial...), got ndim={y_pred.ndim}.") + if num_bins < 1: + raise ValueError(f"num_bins must be >= 1, got {num_bins}.") + + batch_size, num_channels = y_pred.shape[:2] + boundaries = torch.linspace( + start=0.0, end=1.0 + torch.finfo(torch.float32).eps, steps=num_bins + 1, device=y_pred.device + ) + + mean_p_per_bin = torch.zeros(batch_size, num_channels, num_bins, device=y_pred.device) + mean_gt_per_bin = torch.zeros_like(mean_p_per_bin) + bin_counts = torch.zeros_like(mean_p_per_bin) + + y_pred_flat = y_pred.flatten(start_dim=2).float() + y_flat = y.flatten(start_dim=2).float() + + for b in range(batch_size): + for c in range(num_channels): + values_p = y_pred_flat[b, c, :] + values_gt = y_flat[b, c, :] + + # Compute bin indices and clamp to valid range to handle out-of-range values + bin_idx = torch.bucketize(values_p, boundaries[1:], right=right) + bin_idx = bin_idx.clamp(max=num_bins - 1) + + # Compute bin counts using scatter_add + counts = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + counts.scatter_add_(0, bin_idx, torch.ones_like(values_p)) + bin_counts[b, c, :] = counts + + # Compute sums for mean calculation using scatter_add (more compatible than scatter_reduce) + sum_p = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + sum_p.scatter_add_(0, bin_idx, values_p) + + sum_gt = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + sum_gt.scatter_add_(0, bin_idx, values_gt) + + # Compute means, avoiding division by zero + safe_counts = counts.clamp(min=1) + mean_p_per_bin[b, c, :] = sum_p / safe_counts + mean_gt_per_bin[b, c, :] = sum_gt / safe_counts + + # Set empty bins to NaN + mean_p_per_bin[bin_counts == 0] = torch.nan + mean_gt_per_bin[bin_counts == 0] = torch.nan + + return mean_p_per_bin, mean_gt_per_bin, bin_counts + + +class CalibrationReduction(StrEnum): + """ + Enumeration of calibration error reduction methods for aggregating per-bin calibration errors. + + - **EXPECTED**: Expected Calibration Error (ECE) - weighted average of per-bin errors by bin count. + This is the most commonly used calibration metric, giving more weight to bins with more samples. + - **AVERAGE**: Average Calibration Error (ACE) - unweighted mean of per-bin errors. + Treats all bins equally regardless of sample count. + - **MAXIMUM**: Maximum Calibration Error (MCE) - worst-case calibration error across all bins. + Useful for identifying the confidence range with poorest calibration. + + References: + - Naeini, M.P., et al. "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI 2015. + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. + """ + + EXPECTED = "expected" + AVERAGE = "average" + MAXIMUM = "maximum" + + +class CalibrationErrorMetric(CumulativeIterationMetric): + """ + Compute the Calibration Error between predicted probabilities and ground truth labels. 
+ + **Why Calibration Matters:** + + A well-calibrated classifier produces probability estimates that reflect true correctness likelihood. + For instance, if a model predicts 80% probability for class A, a well calibrated and reliable model + should be correct approximately 80% of the time among all such predictions. + Modern neural networks, despite high accuracy, are often poorly calibrated, as they tend to be + overconfident in their predictions. + This is particularly important in medical imaging where probability estimates may inform clinical decisions. + + **How It Works:** + + This metric uses a binning approach: predictions are grouped into bins based on their confidence + (predicted probability), and for each bin, the average confidence is compared to the empirical + accuracy (fraction of correct predictions). The calibration error measures the discrepancy between + these values across all bins. + + Three reduction modes are supported: + + - **Expected Calibration Error (ECE)**: Weighted average of per-bin errors, where weights are + proportional to the number of samples in each bin. Most commonly used metric. + - **Average Calibration Error (ACE)**: Simple unweighted average across bins. + - **Maximum Calibration Error (MCE)**: The largest calibration error among all bins. + + The metric supports both single-channel and multi-channel data in the format ``(B, C, H, W[, D])``, + where B is batch size, C is number of classes, and H, W, D are spatial dimensions. + + Args: + num_bins: Number of equally-spaced bins to divide the [0, 1] probability range into. + Defaults to 20. + include_background: Whether to include the first channel (index 0) in the computation. + Set to ``False`` to exclude background class, which is useful in segmentation tasks + where background may dominate and skew calibration results. Defaults to ``True``. + calibration_reduction: Method for calculating calibration error from binned data. + Available modes: ``"expected"`` (ECE), ``"average"`` (ACE), ``"maximum"`` (MCE). + Defaults to ``"expected"``. + metric_reduction: Reduction mode to apply across batch/channel dimensions after computing + per-sample calibration errors. Available modes: ``"none"``, ``"mean"``, ``"sum"``, + ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``. + Defaults to ``"mean"``. + get_not_nans: If ``True``, ``aggregate()`` returns a tuple ``(metric, not_nans)`` where + ``not_nans`` is the count of non-NaN values. Defaults to ``False``. + right: Bin boundary inclusion rule. If ``False`` (default), bins are ``[left, right)``. + If ``True``, bins are ``(left, right]``. + + References: + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. + https://proceedings.mlr.press/v70/guo17a.html + - Barfoot, T., et al. "Average Calibration Error: A Differentiable Loss for Improved + Reliability in Image Segmentation." MICCAI 2024. + https://papers.miccai.org/miccai-2024/091-Paper3075.html + + See Also: + - :py:class:`monai.handlers.CalibrationError`: Ignite handler wrapper for this metric. + - :py:func:`calibration_binning`: Low-level binning function for reliability diagrams. + + Example: + Typical execution steps follow :py:class:`monai.metrics.metric.Cumulative`. 
+ + >>> import torch + >>> from monai.metrics import CalibrationErrorMetric + >>> from monai.transforms import Activations, AsDiscrete + >>> + >>> # Setup transforms for probability conversion + >>> num_classes = 3 + >>> softmax = Activations(softmax=True) # convert logits to probabilities + >>> to_onehot = AsDiscrete(to_onehot=num_classes) # convert labels to one-hot + >>> + >>> # Create metric (Expected Calibration Error, excluding background) + >>> metric = CalibrationErrorMetric( + ... num_bins=15, + ... include_background=False, + ... calibration_reduction="expected" + ... ) + >>> + >>> # Evaluation loop + >>> for batch_data in dataloader: + ... logits, labels = model(batch_data) + ... preds = softmax(logits) # shape: (B, C, H, W) with values in [0, 1] + ... labels_onehot = to_onehot(labels) # shape: (B, C, H, W) with values 0 or 1 + ... metric(y_pred=preds, y=labels_onehot) + >>> + >>> # Get final calibration error + >>> ece = metric.aggregate() + >>> print(f"Expected Calibration Error: {ece:.4f}") + """ + + def __init__( + self, + num_bins: int = 20, + include_background: bool = True, + calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, + metric_reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + right: bool = False, + ) -> None: + super().__init__() + self.num_bins = num_bins + self.include_background = include_background + self.calibration_reduction = CalibrationReduction(calibration_reduction) + self.metric_reduction = metric_reduction + self.get_not_nans = get_not_nans + self.right = right + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] + """ + Compute calibration error for the given predictions and ground truth. + + Args: + y_pred: input data to compute. It should be in the format of (batch, channel, spatial...). + It represents probability predictions of the model. + y: ground truth in one-hot format. It should be in the format of (batch, channel, spatial...). + The values should be binarized. + **kwargs: additional keyword arguments (unused, for API compatibility). + + Returns: + Calibration error tensor with shape (batch, channel). 
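+
+        Note:
+            Writing the per-bin mean confidence as ``conf_k``, the per-bin accuracy as
+            ``acc_k`` and the bin count as ``n_k`` (``N`` samples in total), the supported
+            reductions amount to ``ECE = sum_k (n_k / N) * |conf_k - acc_k|``,
+            ``ACE = mean_k |conf_k - acc_k|`` and ``MCE = max_k |conf_k - acc_k|``,
+            computed over non-empty bins only.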
+ """ + if not self.include_background: + y_pred, y = ignore_background(y_pred=y_pred, y=y) + + mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning( + y_pred=y_pred, y=y, num_bins=self.num_bins, right=self.right + ) + + # Calculate the absolute differences, ignoring nan values + abs_diff = torch.abs(mean_p_per_bin - mean_gt_per_bin) + + if self.calibration_reduction == CalibrationReduction.EXPECTED: + # Calculate the weighted sum of absolute differences + return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1) + elif self.calibration_reduction == CalibrationReduction.AVERAGE: + return torch.nanmean(abs_diff, dim=-1) # Average across all dimensions, ignoring nan + elif self.calibration_reduction == CalibrationReduction.MAXIMUM: + # Replace NaN with -inf for max computation, then restore NaN for all-NaN cases + abs_diff_for_max = torch.nan_to_num(abs_diff, nan=float("-inf")) + max_vals = torch.max(abs_diff_for_max, dim=-1).values + # Restore NaN where all bins were empty (max is -inf) + max_vals = torch.where( + max_vals == float("-inf"), torch.tensor(float("nan"), device=max_vals.device), max_vals + ) + return max_vals + else: + raise ValueError(f"Unsupported calibration reduction: {self.calibration_reduction}") + + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Execute reduction logic for the output of `_compute_tensor`. + + Args: + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to `self.metric_reduction`. if "none", will not + do reduction. + + Returns: + If `get_not_nans` is True, returns a tuple (metric, not_nans), otherwise returns only the metric. + """ + data = self.get_buffer() + if not isinstance(data, torch.Tensor): + raise ValueError("the data to aggregate must be PyTorch Tensor.") + + # do metric reduction + f, not_nans = do_metric_reduction(data, reduction or self.metric_reduction) + return (f, not_nans) if self.get_not_nans else f diff --git a/tests/handlers/test_handler_calibration_error.py b/tests/handlers/test_handler_calibration_error.py new file mode 100644 index 0000000000..1837f721c7 --- /dev/null +++ b/tests/handlers/test_handler_calibration_error.py @@ -0,0 +1,176 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.handlers import CalibrationError, from_engine +from monai.utils import IgniteInfo, min_version, optional_import +from tests.test_utils import assert_allclose + +Engine, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# Test cases for handler +# Format: [input_params, expected_value, expected_rows, expected_channels] +TEST_CASE_1 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "expected", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2250, + 4, # 2 batches * 2 iterations + 2, # 2 channels +] + +TEST_CASE_2 = [ + { + "num_bins": 5, + "include_background": False, + "calibration_reduction": "expected", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2500, + 4, # 2 batches * 2 iterations + 1, # 1 channel (background excluded) +] + +TEST_CASE_3 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "average", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2584, # Mean of [[0.2000, 0.4667], [0.2000, 0.1667]] + 4, + 2, +] + +TEST_CASE_4 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "maximum", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.4000, # Mean of [[0.3000, 0.7000], [0.3000, 0.3000]] + 4, + 2, +] + + +@unittest.skipUnless(has_ignite, "Requires pytorch-ignite") +class TestHandlerCalibrationError(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_compute(self, input_params, expected_value, expected_rows, expected_channels): + calibration_metric = CalibrationError(**input_params) + + # Test data: 2 batches with 2 channels each + y_pred = torch.tensor( + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] + ).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device) + + # Create data as list of batches (2 iterations) + data = [{"pred": y_pred, "label": y}, {"pred": y_pred, "label": y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose( + engine.state.metrics["calibration_error"], expected_value, atol=1e-4, rtol=1e-4, type_test=False + ) + + # Check details shape using invariants rather than exact tuple + details = engine.state.metric_details["calibration_error"] + self.assertEqual(details.shape[0], expected_rows) + self.assertEqual(details.shape[-1], expected_channels) + + +@unittest.skipUnless(has_ignite, "Requires pytorch-ignite") +class TestHandlerCalibrationErrorEdgeCases(unittest.TestCase): + + def test_single_iteration(self): + """Test handler with single iteration.""" + calibration_metric = CalibrationError( + num_bins=5, + include_background=True, + calibration_reduction="expected", + metric_reduction="mean", + output_transform=from_engine(["pred", "label"]), + ) + + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + data = [{"pred": y_pred, "label": y}] + + def 
_val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose(engine.state.metrics["calibration_error"], 0.2, atol=1e-4, rtol=1e-4, type_test=False) + + def test_save_details_false(self): + """Test handler with save_details=False.""" + calibration_metric = CalibrationError( + num_bins=5, + include_background=True, + calibration_reduction="expected", + metric_reduction="mean", + output_transform=from_engine(["pred", "label"]), + save_details=False, + ) + + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + data = [{"pred": y_pred, "label": y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose(engine.state.metrics["calibration_error"], 0.2, atol=1e-4, rtol=1e-4, type_test=False) + + # When save_details=False, metric_details should not exist or should not have the metric key + if hasattr(engine.state, "metric_details"): + self.assertNotIn("calibration_error", engine.state.metric_details or {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metrics/test_calibration_metric.py b/tests/metrics/test_calibration_metric.py new file mode 100644 index 0000000000..17dcfef6c2 --- /dev/null +++ b/tests/metrics/test_calibration_metric.py @@ -0,0 +1,319 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
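+
+# Worked reference for the "small_mid" binning case below: with 5 bins of width 0.2 and
+# right=False, the predictions 0.1, 0.3, 0.7 and 0.9 fall into bins 0, 1, 3 and 4
+# (bin 2 stays empty, hence the NaN means), and the matching targets 0, 0, 1, 1 give the
+# expected per-bin accuracies.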
+ +from __future__ import annotations + +import unittest +from unittest import mock + +import torch +from parameterized import parameterized + +from monai.metrics import CalibrationErrorMetric, CalibrationReduction, calibration_binning +from monai.utils import MetricReduction +from tests.test_utils import assert_allclose + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# Test cases for calibration binning +# Format: [name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, expected_counts] +TEST_BINNING_SMALL_MID = [ + "small_mid", + torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + False, + torch.tensor([[[0.1, 0.3, float("nan"), 0.7, 0.9]]]), + torch.tensor([[[0.0, 0.0, float("nan"), 1.0, 1.0]]]), + torch.tensor([[[1.0, 1.0, 0.0, 1.0, 1.0]]]), +] + +TEST_BINNING_LARGE_MID = [ + "large_mid", + torch.tensor( + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] + ), + torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]), + 5, + False, + torch.tensor( + [ + [[0.1, 0.3, float("nan"), 0.7, 0.9], [float("nan"), 0.3, 0.5, 0.7, float("nan")]], + [[float("nan"), 0.3, float("nan"), float("nan"), 0.9], [0.1, float("nan"), float("nan"), 0.7, 0.9]], + ] + ), + torch.tensor( + [ + [[0.0, 0.0, float("nan"), 1.0, 1.0], [float("nan"), 1.0, 0.5, 0.0, float("nan")]], + [[float("nan"), 0.0, float("nan"), float("nan"), 1.0], [0.0, float("nan"), float("nan"), 1.0, 1.0]], + ] + ), + torch.tensor( + [[[1.0, 1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 2.0, 1.0, 0.0]], [[0.0, 2.0, 0.0, 0.0, 2.0], [2.0, 0.0, 0.0, 1.0, 1.0]]] + ), +] + +TEST_BINNING_SMALL_LEFT_EDGE = [ + "small_left_edge", + torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + False, + torch.tensor([[[0.2, 0.4, 0.6, 0.8, float("nan")]]]), + torch.tensor([[[0.0, 0.0, 1.0, 1.0, float("nan")]]]), + torch.tensor([[[1.0, 1.0, 1.0, 1.0, 0.0]]]), +] + +TEST_BINNING_SMALL_RIGHT_EDGE = [ + "small_right_edge", + torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + True, + torch.tensor([[[float("nan"), 0.2, 0.4, 0.6, 0.8]]]), + torch.tensor([[[float("nan"), 0.0, 0.0, 1.0, 1.0]]]), + torch.tensor([[[0.0, 1.0, 1.0, 1.0, 1.0]]]), +] + +BINNING_TEST_CASES = [ + TEST_BINNING_SMALL_MID, + TEST_BINNING_LARGE_MID, + TEST_BINNING_SMALL_LEFT_EDGE, + TEST_BINNING_SMALL_RIGHT_EDGE, +] + +# Test cases for calibration error metric values +# Format: [name, y_pred, y, num_bins, expected_expected, expected_average, expected_maximum] +TEST_VALUE_1B1C = [ + "1b1c", + torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + torch.tensor([[0.2]]), + torch.tensor([[0.2]]), + torch.tensor([[0.3]]), +] + +TEST_VALUE_2B2C = [ + "2b2c", + torch.tensor( + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] + ), + torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]), + 5, + torch.tensor([[0.2000, 0.3500], [0.2000, 0.1500]]), + torch.tensor([[0.2000, 0.4667], [0.2000, 0.1667]]), + torch.tensor([[0.3000, 0.7000], [0.3000, 0.3000]]), +] + +VALUE_TEST_CASES = [TEST_VALUE_1B1C, TEST_VALUE_2B2C] + + +class TestCalibrationBinning(unittest.TestCase): + + @parameterized.expand(BINNING_TEST_CASES) + def test_binning(self, _name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, expected_counts): + y_pred = y_pred.to(_device) + y = 
y.to(_device) + expected_mean_p = expected_mean_p.to(_device) + expected_mean_gt = expected_mean_gt.to(_device) + expected_counts = expected_counts.to(_device) + + # Use mock.patch to replace torch.linspace + # This is to avoid floating point precision issues when looking at edge conditions + mock_boundaries = torch.tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], device=_device) + with mock.patch("monai.metrics.calibration.torch.linspace", return_value=mock_boundaries): + mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning(y_pred, y, num_bins=num_bins, right=right) + + # Handle NaN comparisons: compare NaN masks separately, then compare non-NaN values + # mean_p_per_bin + self.assertTrue(torch.equal(torch.isnan(mean_p_per_bin), torch.isnan(expected_mean_p))) + mask_p = ~torch.isnan(expected_mean_p) + if mask_p.any(): + assert_allclose(mean_p_per_bin[mask_p], expected_mean_p[mask_p], atol=1e-4, rtol=1e-4) + + # mean_gt_per_bin + self.assertTrue(torch.equal(torch.isnan(mean_gt_per_bin), torch.isnan(expected_mean_gt))) + mask_gt = ~torch.isnan(expected_mean_gt) + if mask_gt.any(): + assert_allclose(mean_gt_per_bin[mask_gt], expected_mean_gt[mask_gt], atol=1e-4, rtol=1e-4) + + # bin_counts (no NaNs) + assert_allclose(bin_counts, expected_counts, atol=1e-4, rtol=1e-4) + + def test_shape_mismatch_raises(self): + """Test that mismatched shapes raise ValueError.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1], [0, 0]]]]).to(_device) # Different shape + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=5) + self.assertIn("same shape", str(context.exception)) + + def test_insufficient_ndim_raises(self): + """Test that tensors with ndim < 3 raise ValueError.""" + y_pred = torch.tensor([[0.7, 0.3]]).to(_device) # Only 2D + y = torch.tensor([[1, 0]]).to(_device) + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=5) + self.assertIn("ndim", str(context.exception)) + + def test_invalid_num_bins_raises(self): + """Test that num_bins < 1 raises ValueError.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=0) + self.assertIn("num_bins", str(context.exception)) + + +class TestCalibrationErrorMetricValue(unittest.TestCase): + + @parameterized.expand(VALUE_TEST_CASES) + def test_expected_reduction(self, _name, y_pred, y, num_bins, expected_expected, _expected_average, _expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_expected = expected_expected.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, expected_expected, atol=1e-4, rtol=1e-4) + + @parameterized.expand(VALUE_TEST_CASES) + def test_average_reduction(self, _name, y_pred, y, num_bins, _expected_expected, expected_average, _expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_average = expected_average.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.AVERAGE, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + 
assert_allclose(result, expected_average, atol=1e-4, rtol=1e-4) + + @parameterized.expand(VALUE_TEST_CASES) + def test_maximum_reduction(self, _name, y_pred, y, num_bins, _expected_expected, _expected_average, expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_max = expected_max.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.MAXIMUM, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, expected_max, atol=1e-4, rtol=1e-4) + + +class TestCalibrationErrorMetricOptions(unittest.TestCase): + + def test_include_background_false(self): + y_pred = torch.tensor( + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] + ).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=False, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, torch.tensor(0.2500, device=_device), atol=1e-4, rtol=1e-4) + + def test_metric_reduction_mean(self): + y_pred = torch.tensor( + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] + ).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + # Mean of [[0.2000, 0.3500], [0.2000, 0.1500]] = 0.225 + assert_allclose(result, torch.tensor(0.2250, device=_device), atol=1e-4, rtol=1e-4) + + def test_get_not_nans(self): + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + get_not_nans=True, + ) + + metric(y_pred=y_pred, y=y) + result, not_nans = metric.aggregate() + + assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4) + self.assertEqual(not_nans.item(), 1) + + def test_cumulative_iterations(self): + """Test that the metric correctly accumulates over multiple iterations.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + # First iteration + metric(y_pred=y_pred, y=y) + # Second iteration + metric(y_pred=y_pred, y=y) + + result = metric.aggregate() + # Should still be 0.2 since both iterations have the same data + assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4) + + # Test reset + metric.reset() + data = metric.get_buffer() + self.assertIsNone(data) + + +if __name__ == "__main__": + unittest.main()
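For reference, the new metric can also be exercised directly, outside an Ignite engine. The sketch below assumes the modules added in this diff are importable from ``monai.metrics`` and reuses the single-batch, single-channel tensors from the tests above, for which the expected ECE is 0.2::

    import torch

    from monai.metrics import CalibrationErrorMetric

    # single batch, single channel, 2x2 spatial grid (same data as the "1b1c" test case)
    y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]])
    y = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])

    metric = CalibrationErrorMetric(num_bins=5, calibration_reduction="expected")
    metric(y_pred=y_pred, y=y)   # accumulate one iteration
    ece = metric.aggregate()     # count-weighted per-bin gaps -> 0.2 for this data
    metric.reset()               # clear the buffer before the next evaluation run
    print(f"ECE: {float(ece):.4f}")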