From 9dde999c5e94c9122829da670c4f5f10dfd8db79 Mon Sep 17 00:00:00 2001 From: Theo Barfoot Date: Fri, 16 Jan 2026 16:42:02 +0000 Subject: [PATCH 1/3] Add CalibrationErrorMetric and CalibrationError handler - Add calibration_binning() function for hard binning calibration - Add CalibrationErrorMetric with ECE/ACE/MCE reduction modes - Add CalibrationError Ignite handler - Add comprehensive tests for metrics and handler Addresses #8505 Signed-off-by: Theo Barfoot --- monai/handlers/__init__.py | 1 + monai/handlers/calibration.py | 72 ++++ monai/metrics/__init__.py | 1 + monai/metrics/calibration.py | 260 +++++++++++++ .../test_handler_calibration_error.py | 184 +++++++++ tests/metrics/test_calibration_metric.py | 357 ++++++++++++++++++ 6 files changed, 875 insertions(+) create mode 100644 monai/handlers/calibration.py create mode 100644 monai/metrics/calibration.py create mode 100644 tests/handlers/test_handler_calibration_error.py create mode 100644 tests/metrics/test_calibration_metric.py diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index ed5db8a7f3..7fc7b6df57 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .average_precision import AveragePrecision +from .calibration import CalibrationError from .checkpoint_loader import CheckpointLoader from .checkpoint_saver import CheckpointSaver from .classification_saver import ClassificationSaver diff --git a/monai/handlers/calibration.py b/monai/handlers/calibration.py new file mode 100644 index 0000000000..6e6c2b74a9 --- /dev/null +++ b/monai/handlers/calibration.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable + +from monai.config import IgniteInfo +from monai.handlers.ignite_metric import IgniteMetricHandler +from monai.metrics import CalibrationErrorMetric, CalibrationReduction +from monai.utils import MetricReduction + +__all__ = ["CalibrationError"] + + +class CalibrationError(IgniteMetricHandler): + """ + Computes Calibration Error and reports the aggregated value according to `metric_reduction` + over all accumulated iterations. Can return the expected, average, or maximum calibration error. + + Args: + num_bins: number of bins to calculate calibration. Defaults to 20. + include_background: whether to include calibration error computation on the first channel of + the predicted output. Defaults to True. + calibration_reduction: Method for calculating calibration error values from binned data. + Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. + metric_reduction: Mode of reduction to apply to the metrics. + Reduction is only applied to non-NaN values. + Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, + `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. + Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. + output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then + construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or + lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. + `engine.state` and `output_transform` inherit from the ignite concept: + https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: + https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. + save_details: whether to save metric computation details per image, for example: calibration error + of every image. default to True, will save to `engine.state.metric_details` dict with the + metric name as key. + + """ + + def __init__( + self, + num_bins: int = 20, + include_background: bool = True, + calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, + metric_reduction: MetricReduction | str = MetricReduction.MEAN, + output_transform: Callable = lambda x: x, + save_details: bool = True, + ) -> None: + metric_fn = CalibrationErrorMetric( + num_bins=num_bins, + include_background=include_background, + calibration_reduction=calibration_reduction, + metric_reduction=metric_reduction, + ) + + super().__init__( + metric_fn=metric_fn, + output_transform=output_transform, + save_details=save_details, + ) diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index ae20903cfd..0da25feca9 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -13,6 +13,7 @@ from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score from .average_precision import AveragePrecisionMetric, compute_average_precision +from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .cumulative_average import CumulativeAverage from .f_beta_score import FBetaScore diff --git a/monai/metrics/calibration.py b/monai/metrics/calibration.py new file mode 100644 index 0000000000..8d7b5729b9 --- /dev/null +++ b/monai/metrics/calibration.py @@ -0,0 +1,260 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch + +from monai.metrics.metric import CumulativeIterationMetric +from monai.metrics.utils import do_metric_reduction, ignore_background +from monai.utils import MetricReduction +from monai.utils.enums import StrEnum + +__all__ = [ + "calibration_binning", + "CalibrationErrorMetric", + "CalibrationReduction", +] + + +def calibration_binning( + y_pred: torch.Tensor, y: torch.Tensor, num_bins: int = 20, right: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute calibration bins for predicted probabilities and ground truth labels. + This function calculates the mean predicted probabilities, mean ground truths, + and bin counts for each bin using a hard binning calibration approach. + + The function operates on input and target tensors with batch and channel dimensions, + handling each batch and channel separately. For bins that do not contain any elements, + the mean predicted values and mean ground truth values are set to NaN. + + Args: + y_pred: predicted tensor with shape [batch, channel, spatial], where spatial + can be any number of dimensions. The y_pred tensor represents probabilities. + Values should be in the range [0, 1] (probabilities). + y: Target tensor with the same shape as y_pred. It represents ground truth values. + num_bins: The number of bins to use for calibration. Defaults to 20. Must be >= 1. + right: If False (default), the bins include the left boundary and exclude the right boundary. + If True, the bins exclude the left boundary and include the right boundary. + + Returns: + A tuple of three tensors: + - mean_p_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing + the mean predicted values in each bin. + - mean_gt_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing + the mean ground truth values in each bin. + - bin_counts: Tensor of shape [batch_size, num_channels, num_bins] containing + the count of elements in each bin. + + Raises: + ValueError: If the input and target shapes do not match, if the input has fewer than 3 dimensions, + or if num_bins < 1. + + Note: + This function currently uses nested for loops over batch and channel dimensions + for binning operations. Future improvements may include vectorizing these operations + for enhanced performance. + """ + # Input validation + if y_pred.shape != y.shape: + raise ValueError(f"y_pred and y must have the same shape, got {y_pred.shape} and {y.shape}.") + if y_pred.ndim < 3: + raise ValueError(f"y_pred must have shape (B, C, spatial...), got ndim={y_pred.ndim}.") + if num_bins < 1: + raise ValueError(f"num_bins must be >= 1, got {num_bins}.") + + batch_size, num_channels = y_pred.shape[:2] + boundaries = torch.linspace( + start=0.0, + end=1.0 + torch.finfo(torch.float32).eps, + steps=num_bins + 1, + device=y_pred.device, + ) + + mean_p_per_bin = torch.zeros(batch_size, num_channels, num_bins, device=y_pred.device) + mean_gt_per_bin = torch.zeros_like(mean_p_per_bin) + bin_counts = torch.zeros_like(mean_p_per_bin) + + y_pred_flat = y_pred.flatten(start_dim=2).float() + y_flat = y.flatten(start_dim=2).float() + + for b in range(batch_size): + for c in range(num_channels): + values_p = y_pred_flat[b, c, :] + values_gt = y_flat[b, c, :] + + # Compute bin indices and clamp to valid range to handle out-of-range values + bin_idx = torch.bucketize(values_p, boundaries[1:], right=right) + bin_idx = bin_idx.clamp(max=num_bins - 1) + + # Compute bin counts using scatter_add + counts = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + counts.scatter_add_(0, bin_idx, torch.ones_like(values_p)) + bin_counts[b, c, :] = counts + + # Compute sums for mean calculation using scatter_add (more compatible than scatter_reduce) + sum_p = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + sum_p.scatter_add_(0, bin_idx, values_p) + + sum_gt = torch.zeros(num_bins, device=y_pred.device, dtype=torch.float32) + sum_gt.scatter_add_(0, bin_idx, values_gt) + + # Compute means, avoiding division by zero + safe_counts = counts.clamp(min=1) + mean_p_per_bin[b, c, :] = sum_p / safe_counts + mean_gt_per_bin[b, c, :] = sum_gt / safe_counts + + # Set empty bins to NaN + mean_p_per_bin[bin_counts == 0] = torch.nan + mean_gt_per_bin[bin_counts == 0] = torch.nan + + return mean_p_per_bin, mean_gt_per_bin, bin_counts + + +class CalibrationReduction(StrEnum): + """ + Enumeration of calibration error reduction methods. + + - EXPECTED: Expected Calibration Error (ECE) - weighted average by bin count + - AVERAGE: Average Calibration Error (ACE) - simple average across bins + - MAXIMUM: Maximum Calibration Error (MCE) - maximum error across bins + """ + + EXPECTED = "expected" + AVERAGE = "average" + MAXIMUM = "maximum" + + +class CalibrationErrorMetric(CumulativeIterationMetric): + """ + Compute the Calibration Error between predicted probabilities and ground truth labels. + This metric is suitable for multi-class tasks and supports batched inputs. + + The input `y_pred` represents the model's predicted probabilities, and `y` represents the ground truth labels. + `y_pred` is expected to have probabilities, and `y` should be in one-hot format. You can use suitable transforms + in `monai.transforms.post` to achieve the desired format. + + The `include_background` parameter can be set to `False` to exclude the first category (channel index 0), + which is conventionally assumed to be the background. This is particularly useful in segmentation tasks where + the background class might skew the calibration results. + + The metric supports both single-channel and multi-channel data. For multi-channel data, the input tensors + should be in the format of BCHW[D], where B is the batch size, C is the number of channels, and HW[D] + are the spatial dimensions. + + Args: + num_bins: Number of bins to divide probabilities into for calibration calculation. Defaults to 20. + include_background: Whether to include computation on the first channel of the predicted output. + Defaults to `True`. + calibration_reduction: Method for calculating calibration error values from binned data. + Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. + metric_reduction: Mode of reduction to apply to the metrics. + Reduction is only applied to non-NaN values. + Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, + `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. + Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. + get_not_nans: Whether to return the count of non-NaN values. + If `True`, `aggregate()` returns a tuple (metric, not_nans). Defaults to `False`. + right: Whether to use the right or left bin edge for binning. Defaults to `False` (left). + + Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + + Example: + >>> from monai.transforms import Activations, AsDiscrete + >>> # Transforms to convert model outputs to probabilities and labels to one-hot + >>> softmax = Activations(softmax=True) # or sigmoid=True for binary/multi-label + >>> to_onehot = AsDiscrete(to_onehot=num_classes) + >>> metric = CalibrationErrorMetric(num_bins=15, include_background=False, calibration_reduction="expected") + >>> for batch_data in dataloader: + >>> logits, labels = model(batch_data) + >>> preds = softmax(logits) # convert logits to probabilities + >>> labels_onehot = to_onehot(labels) # convert labels to one-hot format + >>> metric(y_pred=preds, y=labels_onehot) + >>> ece = metric.aggregate() + """ + + def __init__( + self, + num_bins: int = 20, + include_background: bool = True, + calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED, + metric_reduction: MetricReduction | str = MetricReduction.MEAN, + get_not_nans: bool = False, + right: bool = False, + ) -> None: + super().__init__() + self.num_bins = num_bins + self.include_background = include_background + self.calibration_reduction = CalibrationReduction(calibration_reduction) + self.metric_reduction = metric_reduction + self.get_not_nans = get_not_nans + self.right = right + + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override] + """ + Compute calibration error for the given predictions and ground truth. + + Args: + y_pred: input data to compute. It should be in the format of (batch, channel, spatial...). + It represents probability predictions of the model. + y: ground truth in one-hot format. It should be in the format of (batch, channel, spatial...). + The values should be binarized. + **kwargs: additional keyword arguments (unused, for API compatibility). + + Returns: + Calibration error tensor with shape (batch, channel). + """ + if not self.include_background: + y_pred, y = ignore_background(y_pred=y_pred, y=y) + + mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning( + y_pred=y_pred, y=y, num_bins=self.num_bins, right=self.right + ) + + # Calculate the absolute differences, ignoring nan values + abs_diff = torch.abs(mean_p_per_bin - mean_gt_per_bin) + + if self.calibration_reduction == CalibrationReduction.EXPECTED: + # Calculate the weighted sum of absolute differences + return torch.nansum(abs_diff * bin_counts, dim=-1) / torch.sum(bin_counts, dim=-1) + elif self.calibration_reduction == CalibrationReduction.AVERAGE: + return torch.nanmean(abs_diff, dim=-1) # Average across all dimensions, ignoring nan + elif self.calibration_reduction == CalibrationReduction.MAXIMUM: + abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0) + return torch.max(abs_diff_no_nan, dim=-1).values # Maximum across all dimensions + else: + raise ValueError(f"Unsupported calibration reduction: {self.calibration_reduction}") + + def aggregate( + self, reduction: MetricReduction | str | None = None + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Execute reduction logic for the output of `_compute_tensor`. + + Args: + reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to `self.metric_reduction`. if "none", will not + do reduction. + + Returns: + If `get_not_nans` is True, returns a tuple (metric, not_nans), otherwise returns only the metric. + """ + data = self.get_buffer() + if not isinstance(data, torch.Tensor): + raise ValueError("the data to aggregate must be PyTorch Tensor.") + + # do metric reduction + f, not_nans = do_metric_reduction(data, reduction or self.metric_reduction) + return (f, not_nans) if self.get_not_nans else f diff --git a/tests/handlers/test_handler_calibration_error.py b/tests/handlers/test_handler_calibration_error.py new file mode 100644 index 0000000000..5cc0c2609c --- /dev/null +++ b/tests/handlers/test_handler_calibration_error.py @@ -0,0 +1,184 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.handlers import CalibrationError, from_engine +from monai.utils import IgniteInfo, min_version, optional_import +from tests.test_utils import assert_allclose + +Engine, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# Test cases for handler +# Format: [input_params, expected_value, expected_rows, expected_channels] +TEST_CASE_1 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "expected", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2250, + 4, # 2 batches * 2 iterations + 2, # 2 channels +] + +TEST_CASE_2 = [ + { + "num_bins": 5, + "include_background": False, + "calibration_reduction": "expected", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2500, + 4, # 2 batches * 2 iterations + 1, # 1 channel (background excluded) +] + +TEST_CASE_3 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "average", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.2584, # Mean of [[0.2000, 0.4667], [0.2000, 0.1667]] + 4, + 2, +] + +TEST_CASE_4 = [ + { + "num_bins": 5, + "include_background": True, + "calibration_reduction": "maximum", + "metric_reduction": "mean", + "output_transform": from_engine(["pred", "label"]), + }, + 0.4000, # Mean of [[0.3000, 0.7000], [0.3000, 0.3000]] + 4, + 2, +] + + +@unittest.skipUnless(has_ignite, "Requires pytorch-ignite") +class TestHandlerCalibrationError(unittest.TestCase): + + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + def test_compute(self, input_params, expected_value, expected_rows, expected_channels): + calibration_metric = CalibrationError(**input_params) + + # Test data: 2 batches with 2 channels each + y_pred = torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ).to(_device) + y = torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ).to(_device) + + # Create data as list of batches (2 iterations) + data = [{"pred": y_pred, "label": y}, {"pred": y_pred, "label": y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose( + engine.state.metrics["calibration_error"], expected_value, atol=1e-4, rtol=1e-4, type_test=False + ) + + # Check details shape using invariants rather than exact tuple + details = engine.state.metric_details["calibration_error"] + self.assertEqual(details.shape[0], expected_rows) + self.assertEqual(details.shape[-1], expected_channels) + + +@unittest.skipUnless(has_ignite, "Requires pytorch-ignite") +class TestHandlerCalibrationErrorEdgeCases(unittest.TestCase): + + def test_single_iteration(self): + """Test handler with single iteration.""" + calibration_metric = CalibrationError( + num_bins=5, + include_background=True, + calibration_reduction="expected", + metric_reduction="mean", + output_transform=from_engine(["pred", "label"]), + ) + + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + data = [{"pred": y_pred, "label": y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose(engine.state.metrics["calibration_error"], 0.2, atol=1e-4, rtol=1e-4, type_test=False) + + def test_save_details_false(self): + """Test handler with save_details=False.""" + calibration_metric = CalibrationError( + num_bins=5, + include_background=True, + calibration_reduction="expected", + metric_reduction="mean", + output_transform=from_engine(["pred", "label"]), + save_details=False, + ) + + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + data = [{"pred": y_pred, "label": y}] + + def _val_func(engine, batch): + return batch + + engine = Engine(_val_func) + calibration_metric.attach(engine=engine, name="calibration_error") + + engine.run(data, max_epochs=1) + + assert_allclose(engine.state.metrics["calibration_error"], 0.2, atol=1e-4, rtol=1e-4, type_test=False) + + # When save_details=False, metric_details should not exist or should not have the metric key + if hasattr(engine.state, "metric_details"): + self.assertNotIn("calibration_error", engine.state.metric_details or {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/metrics/test_calibration_metric.py b/tests/metrics/test_calibration_metric.py new file mode 100644 index 0000000000..f220525793 --- /dev/null +++ b/tests/metrics/test_calibration_metric.py @@ -0,0 +1,357 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import mock + +import torch +from parameterized import parameterized + +from monai.metrics import CalibrationErrorMetric, CalibrationReduction, calibration_binning +from monai.utils import MetricReduction +from tests.test_utils import assert_allclose + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# Test cases for calibration binning +# Format: [name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, expected_counts] +TEST_BINNING_SMALL_MID = [ + "small_mid", + torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + False, + torch.tensor([[[0.1, 0.3, float("nan"), 0.7, 0.9]]]), + torch.tensor([[[0.0, 0.0, float("nan"), 1.0, 1.0]]]), + torch.tensor([[[1.0, 1.0, 0.0, 1.0, 1.0]]]), +] + +TEST_BINNING_LARGE_MID = [ + "large_mid", + torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ), + torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ), + 5, + False, + torch.tensor( + [ + [[0.1, 0.3, float("nan"), 0.7, 0.9], [float("nan"), 0.3, 0.5, 0.7, float("nan")]], + [[float("nan"), 0.3, float("nan"), float("nan"), 0.9], [0.1, float("nan"), float("nan"), 0.7, 0.9]], + ] + ), + torch.tensor( + [ + [[0.0, 0.0, float("nan"), 1.0, 1.0], [float("nan"), 1.0, 0.5, 0.0, float("nan")]], + [[float("nan"), 0.0, float("nan"), float("nan"), 1.0], [0.0, float("nan"), float("nan"), 1.0, 1.0]], + ] + ), + torch.tensor( + [ + [[1.0, 1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 2.0, 1.0, 0.0]], + [[0.0, 2.0, 0.0, 0.0, 2.0], [2.0, 0.0, 0.0, 1.0, 1.0]], + ] + ), +] + +TEST_BINNING_SMALL_LEFT_EDGE = [ + "small_left_edge", + torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + False, + torch.tensor([[[0.2, 0.4, 0.6, 0.8, float("nan")]]]), + torch.tensor([[[0.0, 0.0, 1.0, 1.0, float("nan")]]]), + torch.tensor([[[1.0, 1.0, 1.0, 1.0, 0.0]]]), +] + +TEST_BINNING_SMALL_RIGHT_EDGE = [ + "small_right_edge", + torch.tensor([[[[0.8, 0.2], [0.4, 0.6]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + True, + torch.tensor([[[float("nan"), 0.2, 0.4, 0.6, 0.8]]]), + torch.tensor([[[float("nan"), 0.0, 0.0, 1.0, 1.0]]]), + torch.tensor([[[0.0, 1.0, 1.0, 1.0, 1.0]]]), +] + +BINNING_TEST_CASES = [ + TEST_BINNING_SMALL_MID, + TEST_BINNING_LARGE_MID, + TEST_BINNING_SMALL_LEFT_EDGE, + TEST_BINNING_SMALL_RIGHT_EDGE, +] + +# Test cases for calibration error metric values +# Format: [name, y_pred, y, num_bins, expected_expected, expected_average, expected_maximum] +TEST_VALUE_1B1C = [ + "1b1c", + torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]), + torch.tensor([[[[1, 0], [0, 1]]]]), + 5, + torch.tensor([[0.2]]), + torch.tensor([[0.2]]), + torch.tensor([[0.3]]), +] + +TEST_VALUE_2B2C = [ + "2b2c", + torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ), + torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ), + 5, + torch.tensor([[0.2000, 0.3500], [0.2000, 0.1500]]), + torch.tensor([[0.2000, 0.4667], [0.2000, 0.1667]]), + torch.tensor([[0.3000, 0.7000], [0.3000, 0.3000]]), +] + +VALUE_TEST_CASES = [ + TEST_VALUE_1B1C, + TEST_VALUE_2B2C, +] + + +class TestCalibrationBinning(unittest.TestCase): + + @parameterized.expand(BINNING_TEST_CASES) + def test_binning(self, _name, y_pred, y, num_bins, right, expected_mean_p, expected_mean_gt, expected_counts): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_mean_p = expected_mean_p.to(_device) + expected_mean_gt = expected_mean_gt.to(_device) + expected_counts = expected_counts.to(_device) + + # Use mock.patch to replace torch.linspace + # This is to avoid floating point precision issues when looking at edge conditions + mock_boundaries = torch.tensor([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], device=_device) + with mock.patch("monai.metrics.calibration.torch.linspace", return_value=mock_boundaries): + mean_p_per_bin, mean_gt_per_bin, bin_counts = calibration_binning(y_pred, y, num_bins=num_bins, right=right) + + # Handle NaN comparisons: compare NaN masks separately, then compare non-NaN values + # mean_p_per_bin + self.assertTrue(torch.equal(torch.isnan(mean_p_per_bin), torch.isnan(expected_mean_p))) + mask_p = ~torch.isnan(expected_mean_p) + if mask_p.any(): + assert_allclose(mean_p_per_bin[mask_p], expected_mean_p[mask_p], atol=1e-4, rtol=1e-4) + + # mean_gt_per_bin + self.assertTrue(torch.equal(torch.isnan(mean_gt_per_bin), torch.isnan(expected_mean_gt))) + mask_gt = ~torch.isnan(expected_mean_gt) + if mask_gt.any(): + assert_allclose(mean_gt_per_bin[mask_gt], expected_mean_gt[mask_gt], atol=1e-4, rtol=1e-4) + + # bin_counts (no NaNs) + assert_allclose(bin_counts, expected_counts, atol=1e-4, rtol=1e-4) + + def test_shape_mismatch_raises(self): + """Test that mismatched shapes raise ValueError.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1], [0, 0]]]]).to(_device) # Different shape + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=5) + self.assertIn("same shape", str(context.exception)) + + def test_insufficient_ndim_raises(self): + """Test that tensors with ndim < 3 raise ValueError.""" + y_pred = torch.tensor([[0.7, 0.3]]).to(_device) # Only 2D + y = torch.tensor([[1, 0]]).to(_device) + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=5) + self.assertIn("ndim", str(context.exception)) + + def test_invalid_num_bins_raises(self): + """Test that num_bins < 1 raises ValueError.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + with self.assertRaises(ValueError) as context: + calibration_binning(y_pred, y, num_bins=0) + self.assertIn("num_bins", str(context.exception)) + + +class TestCalibrationErrorMetricValue(unittest.TestCase): + + @parameterized.expand(VALUE_TEST_CASES) + def test_expected_reduction(self, _name, y_pred, y, num_bins, expected_expected, _expected_average, _expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_expected = expected_expected.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, expected_expected, atol=1e-4, rtol=1e-4) + + @parameterized.expand(VALUE_TEST_CASES) + def test_average_reduction(self, _name, y_pred, y, num_bins, _expected_expected, expected_average, _expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_average = expected_average.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.AVERAGE, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, expected_average, atol=1e-4, rtol=1e-4) + + @parameterized.expand(VALUE_TEST_CASES) + def test_maximum_reduction(self, _name, y_pred, y, num_bins, _expected_expected, _expected_average, expected_max): + y_pred = y_pred.to(_device) + y = y.to(_device) + expected_max = expected_max.to(_device) + + metric = CalibrationErrorMetric( + num_bins=num_bins, + include_background=True, + calibration_reduction=CalibrationReduction.MAXIMUM, + metric_reduction=MetricReduction.NONE, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, expected_max, atol=1e-4, rtol=1e-4) + + +class TestCalibrationErrorMetricOptions(unittest.TestCase): + + def test_include_background_false(self): + y_pred = torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ).to(_device) + y = torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=False, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + assert_allclose(result, torch.tensor(0.2500, device=_device), atol=1e-4, rtol=1e-4) + + def test_metric_reduction_mean(self): + y_pred = torch.tensor( + [ + [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], + [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], + ] + ).to(_device) + y = torch.tensor( + [ + [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], + [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], + ] + ).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + metric(y_pred=y_pred, y=y) + result = metric.aggregate() + + # Mean of [[0.2000, 0.3500], [0.2000, 0.1500]] = 0.225 + assert_allclose(result, torch.tensor(0.2250, device=_device), atol=1e-4, rtol=1e-4) + + def test_get_not_nans(self): + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + get_not_nans=True, + ) + + metric(y_pred=y_pred, y=y) + result, not_nans = metric.aggregate() + + assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4) + self.assertEqual(not_nans.item(), 1) + + def test_cumulative_iterations(self): + """Test that the metric correctly accumulates over multiple iterations.""" + y_pred = torch.tensor([[[[0.7, 0.3], [0.1, 0.9]]]]).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]]]]).to(_device) + + metric = CalibrationErrorMetric( + num_bins=5, + include_background=True, + calibration_reduction=CalibrationReduction.EXPECTED, + metric_reduction=MetricReduction.MEAN, + ) + + # First iteration + metric(y_pred=y_pred, y=y) + # Second iteration + metric(y_pred=y_pred, y=y) + + result = metric.aggregate() + # Should still be 0.2 since both iterations have the same data + assert_allclose(result, torch.tensor(0.2, device=_device), atol=1e-4, rtol=1e-4) + + # Test reset + metric.reset() + data = metric.get_buffer() + self.assertIsNone(data) + + +if __name__ == "__main__": + unittest.main() From 1c4b62a7625a48cdb7801dc4c9e3a1c2820dae78 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jan 2026 16:49:12 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Theo Barfoot --- monai/handlers/calibration.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/handlers/calibration.py b/monai/handlers/calibration.py index 6e6c2b74a9..afc4f45a50 100644 --- a/monai/handlers/calibration.py +++ b/monai/handlers/calibration.py @@ -13,7 +13,6 @@ from collections.abc import Callable -from monai.config import IgniteInfo from monai.handlers.ignite_metric import IgniteMetricHandler from monai.metrics import CalibrationErrorMetric, CalibrationReduction from monai.utils import MetricReduction From f3446cead00e9ec1394894e247f613f387f468b9 Mon Sep 17 00:00:00 2001 From: Theo Barfoot Date: Mon, 19 Jan 2026 15:32:21 +0000 Subject: [PATCH 3/3] Address PR review feedback - Fix MAXIMUM reduction to return NaN (not 0.0) for all-empty bins (CodeRabbit) - Enhance docstrings with 'Why Calibration Matters' section explaining that probabilities should match observed accuracy - Add paper references: Guo et al. 2017 (ICML primary source), MICCAI 2024, - Add Sphinx autodoc entries to metrics.rst and handlers.rst - Improve parameter documentation and usage examples Signed-off-by: Theo Barfoot --- docs/source/handlers.rst | 6 + docs/source/metrics.rst | 9 + monai/handlers/calibration.py | 90 +++++--- monai/metrics/calibration.py | 211 ++++++++++++------ .../test_handler_calibration_error.py | 12 +- tests/metrics/test_calibration_metric.py | 58 +---- 6 files changed, 232 insertions(+), 154 deletions(-) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 49c84dab28..56142819a1 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -83,6 +83,12 @@ Panoptic Quality metrics handler :members: +Calibration Error metrics handler +--------------------------------- +.. autoclass:: CalibrationError + :members: + + Mean squared error metrics handler ---------------------------------- .. autoclass:: MeanSquaredError diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 88fbea7ff0..ae3d4a0c92 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -180,6 +180,15 @@ Metrics .. autoclass:: MetricsReloadedCategorical :members: +`Calibration Error` +------------------- +.. autofunction:: calibration_binning + +.. autoclass:: CalibrationReduction + :members: + +.. autoclass:: CalibrationErrorMetric + :members: Utilities diff --git a/monai/handlers/calibration.py b/monai/handlers/calibration.py index afc4f45a50..c5fac0919f 100644 --- a/monai/handlers/calibration.py +++ b/monai/handlers/calibration.py @@ -22,30 +22,72 @@ class CalibrationError(IgniteMetricHandler): """ - Computes Calibration Error and reports the aggregated value according to `metric_reduction` - over all accumulated iterations. Can return the expected, average, or maximum calibration error. + Ignite handler to compute Calibration Error during training or evaluation. + + **Why Calibration Matters:** + + A well-calibrated model produces probability estimates that match the true likelihood of correctness. + For example, predictions with 80% confidence should be correct approximately 80% of the time. + Modern neural networks often exhibit poor calibration (typically overconfident), which can be + problematic in medical imaging where probability estimates may inform clinical decisions. + + This handler wraps :py:class:`~monai.metrics.CalibrationErrorMetric` for use with PyTorch Ignite + engines, automatically computing and aggregating calibration errors across iterations. + + **Supported Calibration Metrics:** + + - **Expected Calibration Error (ECE)**: Weighted average of per-bin errors (most common). + - **Average Calibration Error (ACE)**: Unweighted average across bins. + - **Maximum Calibration Error (MCE)**: Worst-case calibration error. Args: - num_bins: number of bins to calculate calibration. Defaults to 20. - include_background: whether to include calibration error computation on the first channel of - the predicted output. Defaults to True. - calibration_reduction: Method for calculating calibration error values from binned data. - Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. - metric_reduction: Mode of reduction to apply to the metrics. - Reduction is only applied to non-NaN values. - Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, - `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. - Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. - output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then - construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or - lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`. - `engine.state` and `output_transform` inherit from the ignite concept: - https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial: - https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb. - save_details: whether to save metric computation details per image, for example: calibration error - of every image. default to True, will save to `engine.state.metric_details` dict with the - metric name as key. + num_bins: Number of equally-spaced bins for calibration computation. Defaults to 20. + include_background: Whether to include the first channel (index 0) in computation. + Set to ``False`` to exclude background in segmentation tasks. Defaults to ``True``. + calibration_reduction: Calibration error reduction mode. Options: ``"expected"`` (ECE), + ``"average"`` (ACE), ``"maximum"`` (MCE). Defaults to ``"expected"``. + metric_reduction: Reduction across batch/channel after computing per-sample errors. + Options: ``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``. Defaults to ``"mean"``. + output_transform: Callable to extract ``(y_pred, y)`` from ``engine.state.output``. + See `Ignite concepts `_ and + the batch output transform tutorial in the MONAI tutorials repository. + save_details: If ``True``, saves per-sample/per-channel metric values to + ``engine.state.metric_details[name]``. Defaults to ``True``. + + References: + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. + https://proceedings.mlr.press/v70/guo17a.html + - Barfoot, T., et al. "Average Calibration Error: A Differentiable Loss for Improved + Reliability in Image Segmentation." MICCAI 2024. + https://papers.miccai.org/miccai-2024/091-Paper3075.html + See Also: + - :py:class:`~monai.metrics.CalibrationErrorMetric`: The underlying metric class. + - :py:func:`~monai.metrics.calibration_binning`: Low-level binning for reliability diagrams. + + Example: + >>> from monai.handlers import CalibrationError, from_engine + >>> from ignite.engine import Engine + >>> + >>> def evaluation_step(engine, batch): + ... # Returns dict with "pred" (probabilities) and "label" (one-hot) + ... return {"pred": model(batch["image"]), "label": batch["label"]} + >>> + >>> evaluator = Engine(evaluation_step) + >>> + >>> # Attach calibration error handler + >>> CalibrationError( + ... num_bins=15, + ... include_background=False, + ... calibration_reduction="expected", + ... output_transform=from_engine(["pred", "label"]), + ... ).attach(evaluator, name="ECE") + >>> + >>> # After evaluation, access results + >>> evaluator.run(val_loader) + >>> ece = evaluator.state.metrics["ECE"] + >>> print(f"Expected Calibration Error: {ece:.4f}") """ def __init__( @@ -64,8 +106,4 @@ def __init__( metric_reduction=metric_reduction, ) - super().__init__( - metric_fn=metric_fn, - output_transform=output_transform, - save_details=save_details, - ) + super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details) diff --git a/monai/metrics/calibration.py b/monai/metrics/calibration.py index 8d7b5729b9..b3df6aff28 100644 --- a/monai/metrics/calibration.py +++ b/monai/metrics/calibration.py @@ -20,11 +20,7 @@ from monai.utils import MetricReduction from monai.utils.enums import StrEnum -__all__ = [ - "calibration_binning", - "CalibrationErrorMetric", - "CalibrationReduction", -] +__all__ = ["calibration_binning", "CalibrationErrorMetric", "CalibrationReduction"] def calibration_binning( @@ -32,39 +28,62 @@ def calibration_binning( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute calibration bins for predicted probabilities and ground truth labels. - This function calculates the mean predicted probabilities, mean ground truths, - and bin counts for each bin using a hard binning calibration approach. + + This function implements hard binning for calibration analysis, grouping predictions + into bins based on their confidence values and computing statistics for each bin. + These statistics can be used to assess model calibration or plot reliability diagrams. + + A well-calibrated model should have predicted probabilities that match empirical accuracy. + For example, among all predictions with 80% confidence, approximately 80% should be correct. + This function provides the per-bin statistics needed to evaluate this property. The function operates on input and target tensors with batch and channel dimensions, handling each batch and channel separately. For bins that do not contain any elements, the mean predicted values and mean ground truth values are set to NaN. Args: - y_pred: predicted tensor with shape [batch, channel, spatial], where spatial - can be any number of dimensions. The y_pred tensor represents probabilities. - Values should be in the range [0, 1] (probabilities). - y: Target tensor with the same shape as y_pred. It represents ground truth values. - num_bins: The number of bins to use for calibration. Defaults to 20. Must be >= 1. - right: If False (default), the bins include the left boundary and exclude the right boundary. - If True, the bins exclude the left boundary and include the right boundary. + y_pred: Predicted probabilities with shape ``(B, C, spatial...)``, where B is batch size, + C is number of classes/channels, and spatial can be any number of dimensions (H, W, D, etc.). + Values should be in the range [0, 1]. + y: Ground truth tensor with the same shape as ``y_pred``. Should be one-hot encoded + or contain binary values (0 or 1) indicating the true class membership. + num_bins: Number of equally-spaced bins to divide the [0, 1] probability range into. + Defaults to 20. Must be >= 1. + right: Determines bin boundary inclusion. If False (default), bins include the left + boundary and exclude the right (i.e., [left, right)). If True, bins exclude the + left boundary and include the right (i.e., (left, right]). Returns: - A tuple of three tensors: - - mean_p_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing - the mean predicted values in each bin. - - mean_gt_per_bin: Tensor of shape [batch_size, num_channels, num_bins] containing - the mean ground truth values in each bin. - - bin_counts: Tensor of shape [batch_size, num_channels, num_bins] containing - the count of elements in each bin. + A tuple of three tensors, each with shape ``(B, C, num_bins)``: + - **mean_p_per_bin**: Mean predicted probability for samples in each bin. + - **mean_gt_per_bin**: Mean ground truth value (empirical accuracy) for samples in each bin. + - **bin_counts**: Number of samples falling into each bin. + + Bins with no samples have NaN values for mean_p_per_bin and mean_gt_per_bin. Raises: - ValueError: If the input and target shapes do not match, if the input has fewer than 3 dimensions, - or if num_bins < 1. + ValueError: If ``y_pred`` and ``y`` have different shapes, if input has fewer than + 3 dimensions, or if ``num_bins < 1``. + + References: + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. + https://proceedings.mlr.press/v70/guo17a.html + - Barfoot, T., et al. "Average Calibration Error: A Differentiable Loss for Improved + Reliability in Image Segmentation." MICCAI 2024. + https://papers.miccai.org/miccai-2024/091-Paper3075.html Note: - This function currently uses nested for loops over batch and channel dimensions - for binning operations. Future improvements may include vectorizing these operations - for enhanced performance. + This function uses nested loops over batch and channel dimensions for binning operations. + For reliability diagram visualization, use the returned statistics to plot mean predicted + probability vs. empirical accuracy for each bin. + + Example: + >>> import torch + >>> # Binary segmentation: batch=1, channels=2, spatial=4x4 + >>> y_pred = torch.rand(1, 2, 4, 4) # predicted probabilities + >>> y = torch.randint(0, 2, (1, 2, 4, 4)).float() # one-hot ground truth + >>> mean_p, mean_gt, counts = calibration_binning(y_pred, y, num_bins=10) + >>> # mean_p, mean_gt, counts each have shape (1, 2, 10) """ # Input validation if y_pred.shape != y.shape: @@ -76,10 +95,7 @@ def calibration_binning( batch_size, num_channels = y_pred.shape[:2] boundaries = torch.linspace( - start=0.0, - end=1.0 + torch.finfo(torch.float32).eps, - steps=num_bins + 1, - device=y_pred.device, + start=0.0, end=1.0 + torch.finfo(torch.float32).eps, steps=num_bins + 1, device=y_pred.device ) mean_p_per_bin = torch.zeros(batch_size, num_channels, num_bins, device=y_pred.device) @@ -124,11 +140,18 @@ def calibration_binning( class CalibrationReduction(StrEnum): """ - Enumeration of calibration error reduction methods. - - - EXPECTED: Expected Calibration Error (ECE) - weighted average by bin count - - AVERAGE: Average Calibration Error (ACE) - simple average across bins - - MAXIMUM: Maximum Calibration Error (MCE) - maximum error across bins + Enumeration of calibration error reduction methods for aggregating per-bin calibration errors. + + - **EXPECTED**: Expected Calibration Error (ECE) - weighted average of per-bin errors by bin count. + This is the most commonly used calibration metric, giving more weight to bins with more samples. + - **AVERAGE**: Average Calibration Error (ACE) - unweighted mean of per-bin errors. + Treats all bins equally regardless of sample count. + - **MAXIMUM**: Maximum Calibration Error (MCE) - worst-case calibration error across all bins. + Useful for identifying the confidence range with poorest calibration. + + References: + - Naeini, M.P., et al. "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI 2015. + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. """ EXPECTED = "expected" @@ -139,49 +162,91 @@ class CalibrationReduction(StrEnum): class CalibrationErrorMetric(CumulativeIterationMetric): """ Compute the Calibration Error between predicted probabilities and ground truth labels. - This metric is suitable for multi-class tasks and supports batched inputs. - The input `y_pred` represents the model's predicted probabilities, and `y` represents the ground truth labels. - `y_pred` is expected to have probabilities, and `y` should be in one-hot format. You can use suitable transforms - in `monai.transforms.post` to achieve the desired format. + **Why Calibration Matters:** + + A well-calibrated classifier produces probability estimates that reflect true correctness likelihood. + For instance, if a model predicts 80% probability for class A, a well calibrated and reliable model + should be correct approximately 80% of the time among all such predictions. + Modern neural networks, despite high accuracy, are often poorly calibrated, as they tend to be + overconfident in their predictions. + This is particularly important in medical imaging where probability estimates may inform clinical decisions. + + **How It Works:** - The `include_background` parameter can be set to `False` to exclude the first category (channel index 0), - which is conventionally assumed to be the background. This is particularly useful in segmentation tasks where - the background class might skew the calibration results. + This metric uses a binning approach: predictions are grouped into bins based on their confidence + (predicted probability), and for each bin, the average confidence is compared to the empirical + accuracy (fraction of correct predictions). The calibration error measures the discrepancy between + these values across all bins. - The metric supports both single-channel and multi-channel data. For multi-channel data, the input tensors - should be in the format of BCHW[D], where B is the batch size, C is the number of channels, and HW[D] - are the spatial dimensions. + Three reduction modes are supported: + + - **Expected Calibration Error (ECE)**: Weighted average of per-bin errors, where weights are + proportional to the number of samples in each bin. Most commonly used metric. + - **Average Calibration Error (ACE)**: Simple unweighted average across bins. + - **Maximum Calibration Error (MCE)**: The largest calibration error among all bins. + + The metric supports both single-channel and multi-channel data in the format ``(B, C, H, W[, D])``, + where B is batch size, C is number of classes, and H, W, D are spatial dimensions. Args: - num_bins: Number of bins to divide probabilities into for calibration calculation. Defaults to 20. - include_background: Whether to include computation on the first channel of the predicted output. - Defaults to `True`. - calibration_reduction: Method for calculating calibration error values from binned data. - Available modes are `"expected"`, `"average"`, and `"maximum"`. Defaults to `"expected"`. - metric_reduction: Mode of reduction to apply to the metrics. - Reduction is only applied to non-NaN values. - Available reduction modes are `"none"`, `"mean"`, `"sum"`, `"mean_batch"`, - `"sum_batch"`, `"mean_channel"`, and `"sum_channel"`. - Defaults to `"mean"`. If set to `"none"`, no reduction will be performed. - get_not_nans: Whether to return the count of non-NaN values. - If `True`, `aggregate()` returns a tuple (metric, not_nans). Defaults to `False`. - right: Whether to use the right or left bin edge for binning. Defaults to `False` (left). - - Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. + num_bins: Number of equally-spaced bins to divide the [0, 1] probability range into. + Defaults to 20. + include_background: Whether to include the first channel (index 0) in the computation. + Set to ``False`` to exclude background class, which is useful in segmentation tasks + where background may dominate and skew calibration results. Defaults to ``True``. + calibration_reduction: Method for calculating calibration error from binned data. + Available modes: ``"expected"`` (ECE), ``"average"`` (ACE), ``"maximum"`` (MCE). + Defaults to ``"expected"``. + metric_reduction: Reduction mode to apply across batch/channel dimensions after computing + per-sample calibration errors. Available modes: ``"none"``, ``"mean"``, ``"sum"``, + ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``. + Defaults to ``"mean"``. + get_not_nans: If ``True``, ``aggregate()`` returns a tuple ``(metric, not_nans)`` where + ``not_nans`` is the count of non-NaN values. Defaults to ``False``. + right: Bin boundary inclusion rule. If ``False`` (default), bins are ``[left, right)``. + If ``True``, bins are ``(left, right]``. + + References: + - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017. + https://proceedings.mlr.press/v70/guo17a.html + - Barfoot, T., et al. "Average Calibration Error: A Differentiable Loss for Improved + Reliability in Image Segmentation." MICCAI 2024. + https://papers.miccai.org/miccai-2024/091-Paper3075.html + + See Also: + - :py:class:`monai.handlers.CalibrationError`: Ignite handler wrapper for this metric. + - :py:func:`calibration_binning`: Low-level binning function for reliability diagrams. Example: + Typical execution steps follow :py:class:`monai.metrics.metric.Cumulative`. + + >>> import torch + >>> from monai.metrics import CalibrationErrorMetric >>> from monai.transforms import Activations, AsDiscrete - >>> # Transforms to convert model outputs to probabilities and labels to one-hot - >>> softmax = Activations(softmax=True) # or sigmoid=True for binary/multi-label - >>> to_onehot = AsDiscrete(to_onehot=num_classes) - >>> metric = CalibrationErrorMetric(num_bins=15, include_background=False, calibration_reduction="expected") + >>> + >>> # Setup transforms for probability conversion + >>> num_classes = 3 + >>> softmax = Activations(softmax=True) # convert logits to probabilities + >>> to_onehot = AsDiscrete(to_onehot=num_classes) # convert labels to one-hot + >>> + >>> # Create metric (Expected Calibration Error, excluding background) + >>> metric = CalibrationErrorMetric( + ... num_bins=15, + ... include_background=False, + ... calibration_reduction="expected" + ... ) + >>> + >>> # Evaluation loop >>> for batch_data in dataloader: - >>> logits, labels = model(batch_data) - >>> preds = softmax(logits) # convert logits to probabilities - >>> labels_onehot = to_onehot(labels) # convert labels to one-hot format - >>> metric(y_pred=preds, y=labels_onehot) + ... logits, labels = model(batch_data) + ... preds = softmax(logits) # shape: (B, C, H, W) with values in [0, 1] + ... labels_onehot = to_onehot(labels) # shape: (B, C, H, W) with values 0 or 1 + ... metric(y_pred=preds, y=labels_onehot) + >>> + >>> # Get final calibration error >>> ece = metric.aggregate() + >>> print(f"Expected Calibration Error: {ece:.4f}") """ def __init__( @@ -231,8 +296,14 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) elif self.calibration_reduction == CalibrationReduction.AVERAGE: return torch.nanmean(abs_diff, dim=-1) # Average across all dimensions, ignoring nan elif self.calibration_reduction == CalibrationReduction.MAXIMUM: - abs_diff_no_nan = torch.nan_to_num(abs_diff, nan=0.0) - return torch.max(abs_diff_no_nan, dim=-1).values # Maximum across all dimensions + # Replace NaN with -inf for max computation, then restore NaN for all-NaN cases + abs_diff_for_max = torch.nan_to_num(abs_diff, nan=float("-inf")) + max_vals = torch.max(abs_diff_for_max, dim=-1).values + # Restore NaN where all bins were empty (max is -inf) + max_vals = torch.where( + max_vals == float("-inf"), torch.tensor(float("nan"), device=max_vals.device), max_vals + ) + return max_vals else: raise ValueError(f"Unsupported calibration reduction: {self.calibration_reduction}") diff --git a/tests/handlers/test_handler_calibration_error.py b/tests/handlers/test_handler_calibration_error.py index 5cc0c2609c..1837f721c7 100644 --- a/tests/handlers/test_handler_calibration_error.py +++ b/tests/handlers/test_handler_calibration_error.py @@ -88,17 +88,9 @@ def test_compute(self, input_params, expected_value, expected_rows, expected_cha # Test data: 2 batches with 2 channels each y_pred = torch.tensor( - [ - [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], - [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], - ] - ).to(_device) - y = torch.tensor( - [ - [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], - [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], - ] + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] ).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device) # Create data as list of batches (2 iterations) data = [{"pred": y_pred, "label": y}, {"pred": y_pred, "label": y}] diff --git a/tests/metrics/test_calibration_metric.py b/tests/metrics/test_calibration_metric.py index f220525793..17dcfef6c2 100644 --- a/tests/metrics/test_calibration_metric.py +++ b/tests/metrics/test_calibration_metric.py @@ -39,17 +39,9 @@ TEST_BINNING_LARGE_MID = [ "large_mid", torch.tensor( - [ - [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], - [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], - ] - ), - torch.tensor( - [ - [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], - [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], - ] + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] ), + torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]), 5, False, torch.tensor( @@ -65,10 +57,7 @@ ] ), torch.tensor( - [ - [[1.0, 1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 2.0, 1.0, 0.0]], - [[0.0, 2.0, 0.0, 0.0, 2.0], [2.0, 0.0, 0.0, 1.0, 1.0]], - ] + [[[1.0, 1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 2.0, 1.0, 0.0]], [[0.0, 2.0, 0.0, 0.0, 2.0], [2.0, 0.0, 0.0, 1.0, 1.0]]] ), ] @@ -116,27 +105,16 @@ TEST_VALUE_2B2C = [ "2b2c", torch.tensor( - [ - [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], - [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], - ] - ), - torch.tensor( - [ - [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], - [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], - ] + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] ), + torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]), 5, torch.tensor([[0.2000, 0.3500], [0.2000, 0.1500]]), torch.tensor([[0.2000, 0.4667], [0.2000, 0.1667]]), torch.tensor([[0.3000, 0.7000], [0.3000, 0.3000]]), ] -VALUE_TEST_CASES = [ - TEST_VALUE_1B1C, - TEST_VALUE_2B2C, -] +VALUE_TEST_CASES = [TEST_VALUE_1B1C, TEST_VALUE_2B2C] class TestCalibrationBinning(unittest.TestCase): @@ -257,17 +235,9 @@ class TestCalibrationErrorMetricOptions(unittest.TestCase): def test_include_background_false(self): y_pred = torch.tensor( - [ - [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], - [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], - ] - ).to(_device) - y = torch.tensor( - [ - [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], - [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], - ] + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] ).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device) metric = CalibrationErrorMetric( num_bins=5, @@ -283,17 +253,9 @@ def test_include_background_false(self): def test_metric_reduction_mean(self): y_pred = torch.tensor( - [ - [[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], - [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]], - ] - ).to(_device) - y = torch.tensor( - [ - [[[1, 0], [0, 1]], [[0, 1], [1, 0]]], - [[[1, 1], [0, 0]], [[0, 0], [1, 1]]], - ] + [[[[0.7, 0.3], [0.1, 0.9]], [[0.7, 0.3], [0.5, 0.5]]], [[[0.9, 0.9], [0.3, 0.3]], [[0.1, 0.1], [0.9, 0.7]]]] ).to(_device) + y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]]], [[[1, 1], [0, 0]], [[0, 0], [1, 1]]]]).to(_device) metric = CalibrationErrorMetric( num_bins=5,