diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0696cc3070a..f6742f73be5 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -337,6 +337,7 @@ Complete list of metrics metric.Metric metrics_lambda.MetricsLambda MultiLabelConfusionMatrix + MutualInformation precision.Precision PSNR recall.Recall diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 1b23257d4aa..05ce97c4066 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -21,6 +21,7 @@ from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage from ignite.metrics.metrics_lambda import MetricsLambda from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix +from ignite.metrics.mutual_information import MutualInformation from ignite.metrics.nlp.bleu import Bleu from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN from ignite.metrics.precision import Precision @@ -57,6 +58,7 @@ "mIoU", "JaccardIndex", "MultiLabelConfusionMatrix", + "MutualInformation", "Precision", "PSNR", "Recall", diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py index b3d0cff21b6..4208bf205b3 100644 --- a/ignite/metrics/entropy.py +++ b/ignite/metrics/entropy.py @@ -80,9 +80,13 @@ def update(self, output: Sequence[torch.Tensor]) -> None: prob = F.softmax(y_pred, dim=1) log_prob = F.log_softmax(y_pred, dim=1) + + self._update(prob, log_prob) + + def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None: entropy_sum = -torch.sum(prob * log_prob) self._sum_of_entropies += entropy_sum.to(self._device) - self._num_examples += y_pred.shape[0] + self._num_examples += prob.shape[0] @sync_all_reduce("_sum_of_entropies", "_num_examples") def compute(self) -> float: diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py new file mode 100644 index 00000000000..2cca768ce43 --- /dev/null +++ b/ignite/metrics/mutual_information.py @@ -0,0 +1,94 @@ +import 
import torch

from ignite.exceptions import NotComputableError

# Import from the defining submodule rather than the ``ignite.metrics`` package:
# the package ``__init__`` imports this module, so going through the package only
# works while ``Entropy`` happens to be imported earlier in that ``__init__``.
from ignite.metrics.entropy import Entropy
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = ["MutualInformation"]


class MutualInformation(Entropy):
    r"""Calculates the `mutual information <https://en.wikipedia.org/wiki/Mutual_information>`_
    between input :math:`X` and prediction :math:`Y`.

    .. math::
       \begin{align*}
           I(X;Y) &= H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right)
           - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\
           H(\mathbf{p}) &= -\sum_{c=1}^C p_c \log p_c.
       \end{align*}

    where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for :math:`i`-th input,
    and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`.

    Intuitively, this metric measures how well input data are clustered by classes in the feature space [1].

    [1] https://proceedings.mlr.press/v70/hu17b.html

    - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric.
    - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
      or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MutualInformation()
            metric.attach(default_evaluator, 'mutual_information')
            y_true = torch.tensor([0, 1, 2])  # not considered in the MutualInformation metric.
            y_pred = torch.tensor([
                [ 0.0000,  0.6931,  1.0986],
                [ 1.3863,  1.6094,  1.6094],
                [ 0.0000, -2.3026, -2.3026]
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['mutual_information'])

        .. testoutput::

            0.18599730730056763
    """

    # NOTE(review): this assignment replaces, rather than extends, whatever tuple
    # ``Entropy`` declares under the same name -- confirm that the parent's
    # accumulators are still meant to be covered.
    _state_dict_all_req_keys = ("_sum_of_probabilities",)

    @reinit__is_reduced
    def reset(self) -> None:
        """Reset the parent's accumulators and the running sum of class probabilities."""
        super().reset()
        # Starts as a 0-dim tensor; it takes the per-class vector shape on the
        # first ``_update`` call through broadcasting.
        self._sum_of_probabilities = torch.tensor(0.0, device=self._device)

    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
        """Accumulate one batch.

        Args:
            prob: class probabilities of the batch; they are summed over dim 0.
            log_prob: log-probabilities matching ``prob``, forwarded to the parent.
        """
        super()._update(prob, log_prob)
        # We can't use += below as _sum_of_probabilities can be a scalar and prob.sum(dim=0) is a vector
        self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device)

    @sync_all_reduce("_sum_of_probabilities", "_sum_of_entropies", "_num_examples")
    def compute(self) -> float:
        """Return the mutual information estimated from everything accumulated so far.

        Returns:
            The entropy of the averaged probabilities minus the averaged entropy,
            clamped below at zero.

        Raises:
            NotComputableError: if no example has been accumulated.
        """
        n = self._num_examples
        if n == 0:
            raise NotComputableError("MutualInformation must have at least one example before it can be computed.")

        marginal_prob = self._sum_of_probabilities / n
        # ``xlogy(p, p)`` is ``p * log(p)`` with the 0 * log(0) case defined as 0.
        # The plain product gives NaN when a class probability is exactly zero
        # for every example, and that NaN would survive the clamp below.
        marginal_ent = -torch.xlogy(marginal_prob, marginal_prob).sum()
        conditional_ent = self._sum_of_entropies / n
        mi = marginal_ent - conditional_ent
        mi = torch.clamp(mi, min=0.0)  # mutual information cannot be negative
        return float(mi.item())
from typing import Tuple

import numpy as np
import pytest
import torch
from scipy.special import softmax
from scipy.stats import entropy
from torch import Tensor

import ignite.distributed as idist

from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import MutualInformation


def np_mutual_information(np_y_pred: np.ndarray) -> float:
    """NumPy/SciPy reference: entropy of the mean probabilities minus the mean entropy, floored at 0."""
    class_prob = softmax(np_y_pred, axis=1)
    entropy_of_mean = entropy(np.mean(class_prob, axis=0))
    mean_of_entropies = np.mean(entropy(class_prob, axis=1))
    return max(0.0, entropy_of_mean - mean_of_entropies)


def _assert_accumulators_on(mi: MutualInformation, metric_device: torch.device) -> None:
    """Check that the metric and its probability accumulator both live on ``metric_device``."""
    for dev in (mi._device, mi._sum_of_probabilities.device):
        assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"


def test_zero_sample():
    """Computing before any update must raise ``NotComputableError``."""
    expected_message = r"MutualInformation must have at least one example before it can be computed"
    metric = MutualInformation()
    with pytest.raises(NotComputableError, match=expected_message):
        metric.compute()


def test_invalid_shape():
    """A 1-D ``y_pred`` must be rejected with ``ValueError``."""
    expected_message = r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"
    metric = MutualInformation()
    flat_logits = torch.randn(10).float()
    with pytest.raises(ValueError, match=expected_message):
        metric.update((flat_logits, None))


@pytest.fixture(params=list(range(4)))
def test_case(request):
    """Yield one ``(y_pred, y, batch_size)`` triple, picked by the fixture parameter."""
    # NOTE(review): the parameter only ranges over 0..3, so the two
    # "image segmentation" entries at the end of this list are never selected.
    candidates = [
        (torch.randn((100, 10)).float(), torch.randint(0, 10, size=[100]), 1),
        (torch.rand((100, 500)).float(), torch.randint(0, 500, size=[100]), 1),
        # updated batches
        (torch.normal(0.0, 5.0, size=(100, 10)).float(), torch.randint(0, 10, size=[100]), 16),
        (torch.normal(5.0, 3.0, size=(100, 200)).float(), torch.randint(0, 200, size=[100]), 16),
        # image segmentation
        (torch.randn((100, 5, 32, 32)).float(), torch.randint(0, 5, size=(100, 32, 32)), 16),
        (torch.randn((100, 5, 224, 224)).float(), torch.randint(0, 5, size=(100, 224, 224)), 16),
    ]
    return candidates[request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
    """The metric must agree with the NumPy reference, fed in one shot or in batches."""
    y_pred, y, batch_size = test_case

    metric = MutualInformation()
    metric.reset()

    if batch_size <= 1:
        metric.update((y_pred, y))
    else:
        # One more chunk than the integer quotient, so the trailing partial batch is fed too.
        n_iters = y.shape[0] // batch_size + 1
        for start in range(0, n_iters * batch_size, batch_size):
            stop = start + batch_size
            metric.update((y_pred[start:stop], y[start:stop]))

    expected = np_mutual_information(y_pred.numpy())
    actual = metric.compute()

    assert isinstance(actual, float)
    assert pytest.approx(expected, rel=1e-4) == actual


def test_accumulator_detached():
    """Updating with a grad-requiring input must not leave the accumulator requiring grad."""
    metric = MutualInformation()

    logits = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
    targets = torch.zeros(2)
    metric.update((logits, targets))

    assert not metric._sum_of_probabilities.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        """Run the metric through an ``Engine`` and compare with the reference on gathered data."""
        tol = 1e-4
        n_iters = 100
        batch_size = 10
        n_cls = 50
        n_samples = n_iters * batch_size

        device = idist.device()
        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        metric_devices = [torch.device("cpu")] + ([device] if device.type != "xla" else [])

        for metric_device in metric_devices:
            y_true = torch.randint(0, n_cls, size=[n_samples], dtype=torch.long).to(device)
            y_preds = torch.normal(0.0, 3.0, size=(n_samples, n_cls), dtype=torch.float).to(device)

            def step(_engine, batch_idx):
                # Serve the slice of the pre-generated data that belongs to this iteration.
                lo = batch_idx * batch_size
                hi = (batch_idx + 1) * batch_size
                return y_preds[lo:hi], y_true[lo:hi]

            engine = Engine(step)

            metric = MutualInformation(device=metric_device)
            metric.attach(engine, "mutual_information")

            engine.run(data=list(range(n_iters)), max_epochs=1)

            # Gather afterwards so the reference is computed over the data of all ranks.
            y_preds = idist.all_gather(y_preds)
            y_true = idist.all_gather(y_true)

            assert "mutual_information" in engine.state.metrics
            res = engine.state.metrics["mutual_information"]

            true_res = np_mutual_information(y_preds.cpu().numpy())

            assert pytest.approx(true_res, rel=tol) == res

    def test_accumulator_device(self):
        """The accumulator must stay on the metric device before and after an update."""
        device = idist.device()
        metric_devices = [torch.device("cpu")] + ([device] if device.type != "xla" else [])

        for metric_device in metric_devices:
            mi = MutualInformation(device=metric_device)
            _assert_accumulators_on(mi, metric_device)

            y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
            y = torch.zeros(2)
            mi.update((y_pred, y))

            _assert_accumulators_on(mi, metric_device)