From 2e0d9fdd226710a8ed8fdcb15266fcf79d1dd242 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 23 Mar 2024 22:50:16 +0900
Subject: [PATCH 01/19] add MutualInformationMetric

---
 ignite/metrics/__init__.py                    |   2 +
 ignite/metrics/mutual_information.py          | 106 +++++++++++
 .../ignite/metrics/test_mutual_information.py | 175 ++++++++++++++++++
 3 files changed, 283 insertions(+)
 create mode 100644 ignite/metrics/mutual_information.py
 create mode 100644 tests/ignite/metrics/test_mutual_information.py

diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py
index 04b490b9486..f6f360065f6 100644
--- a/ignite/metrics/__init__.py
+++ b/ignite/metrics/__init__.py
@@ -16,6 +16,7 @@
 from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
 from ignite.metrics.metrics_lambda import MetricsLambda
 from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
+from ignite.metrics.mutual_information import MutualInformation
 from ignite.metrics.nlp.bleu import Bleu
 from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
 from ignite.metrics.precision import Precision
@@ -50,6 +51,7 @@
     "mIoU",
     "JaccardIndex",
     "MultiLabelConfusionMatrix",
+    "MutualInformation",
     "Precision",
     "PSNR",
     "Recall",
diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py
new file mode 100644
index 00000000000..81926df7abe
--- /dev/null
+++ b/ignite/metrics/mutual_information.py
@@ -0,0 +1,106 @@
+from typing import Sequence
+
+import torch
+import torch.nn.functional as F
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
+
+__all__ = ["MutualInformation"]
+
+
+class MutualInformation(Metric):
+    r"""Calculates the `mutual information <https://en.wikipedia.org/wiki/Mutual_information>`_
+    between input :math:`X` and prediction :math:`Y`.
+
+    .. math:: I(X;Y) = H(Y) - H(Y|X)
+        = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right)
+        - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i),
+
+    H(\mathbf{p}) = -\sum_{c=1}^C p_c \log p_c.
+
+    where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for :math:`i`-th input,
+    and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`.
+
+    - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric.
+    - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
+      or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.
+
+    Args:
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+
+    Examples:
+        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
+        The output of the engine's ``process_function`` needs to be in the format of
+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
+        to the metric to transform the output into the form expected by the metric.
+ + For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = MutualInformation() + metric.attach(default_evaluator, 'mutual_information') + y_true = torch.tensor([0, 1, 2]) # not considered in the MutualInformation metric. + y_pred = torch.tensor([ + [ 0.0000, 0.6931, 1.0986], + [ 1.3863, 1.6094, 1.6094], + [ 0.0000, -2.3026, -2.3026] + ]) + state = default_evaluator.run([[y_pred, y_true]]) + print(state.metrics['mutual_information']) + + .. testoutput:: + + 0.18599730730056763 + """ + + _state_dict_all_req_keys = ("_sum_of_probabilities", "_sum_of_conditional_entropies", "_num_examples") + + @reinit__is_reduced + def reset(self) -> None: + self._sum_of_probabilities = torch.tensor(0.0, device=self._device) + self._sum_of_conditional_entropies = torch.tensor(0.0, device=self._device) + self._num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + y_pred = output[0].detach() + if y_pred.ndim >= 3: + num_classes = y_pred.shape[1] + # (B, C, ...) -> (B, ..., C) -> (B*..., C) + # regarding as B*... predictions + y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes) + elif y_pred.ndim == 1: + raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.") + + prob = F.softmax(y_pred, dim=1) + log_prob = F.log_softmax(y_pred, dim=1) + ent_sum = -(prob * log_prob).sum() + + self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0) + self._sum_of_conditional_entropies += ent_sum + self._num_examples += y_pred.shape[0] + + @sync_all_reduce("_sum_of_probabilities", "_sum_of_conditional_entropies", "_num_examples") + def compute(self) -> float: + n = self._num_examples + if n == 0: + raise NotComputableError("MutualInformation must have at least one example before it can be computed.") + + marginal_prob = self._sum_of_probabilities / n + marginal_ent = -(marginal_prob * torch.log(marginal_prob)).sum() + conditional_ent = self._sum_of_conditional_entropies / n + mi = marginal_ent - conditional_ent + mi = torch.clamp(mi, min=0.0) # mutual information cannot be negative + return float(mi.item()) diff --git a/tests/ignite/metrics/test_mutual_information.py b/tests/ignite/metrics/test_mutual_information.py new file mode 100644 index 00000000000..8e2f2cd4766 --- /dev/null +++ b/tests/ignite/metrics/test_mutual_information.py @@ -0,0 +1,175 @@ +from typing import Tuple +import os + +import numpy as np +from scipy.special import softmax +from scipy.stats import entropy +import pytest +import torch +from torch import Tensor + +from ignite.engine import Engine +import ignite.distributed as idist +from ignite.exceptions import NotComputableError +from ignite.metrics import MutualInformation + + +def np_mutual_information(np_y_pred: np.ndarray) -> float: + prob = softmax(np_y_pred, axis=1) + marginal_ent = entropy(np.mean(prob, axis=0)) + conditional_ent = np.mean(entropy(prob, axis=1)) + return max(0.0, marginal_ent - conditional_ent) + + +def test_zero_sample(): + mi = MutualInformation() + with pytest.raises( + NotComputableError, match=r"MutualInformation must have at least one example before it can be computed" + ): + mi.compute() + + +def test_invalid_shape(): + mi = MutualInformation() + y_pred = torch.randn(10).float() + with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"): + mi.update((y_pred, None)) + + 
+@pytest.fixture(params=[item for item in range(4)]) +def test_case(request): + return [ + (torch.randn((100, 10)).float(), torch.randint(0, 10, size=[100]), 1), + (torch.rand((100, 500)).float(), torch.randint(0, 500, size=[100]), 1), + # updated batches + (torch.normal(0.0, 5.0, size=(100, 10)).float(), torch.randint(0, 10, size=[100]), 16), + (torch.normal(5.0, 3.0, size=(100, 200)).float(), torch.randint(0, 200, size=[100]), 16), + # image segmentation + (torch.randn((100, 5, 32, 32)).float(), torch.randint(0, 5, size=(100, 32, 32)), 16), + (torch.randn((100, 5, 224, 224)).float(), torch.randint(0, 5, size=(100, 224, 224)), 16), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]): + mi = MutualInformation() + + y_pred, y, batch_size = test_case + + mi.reset() + if batch_size > 1: + n_iters = y.shape[0] // batch_size + 1 + for i in range(n_iters): + idx = i * batch_size + mi.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size])) + else: + mi.update((y_pred, y)) + + np_res = np_mutual_information(y_pred.numpy()) + res = mi.compute() + + assert isinstance(res, float) + assert pytest.approx(np_res, rel=1e-4) == res + + +def _test_distrib_integration(device, tol=1e-4): + rank = idist.get_rank() + torch.manual_seed(12 + rank) + + def _test(metric_device): + n_iters = 100 + batch_size = 10 + n_cls = 50 + + y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device) + y_preds = torch.normal(0.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device) + + def update(engine, i): + return ( + y_preds[i * batch_size: (i + 1) * batch_size], + y_true[i * batch_size: (i + 1) * batch_size], + ) + + engine = Engine(update) + + m = MutualInformation(device=metric_device) + m.attach(engine, "mutual_information") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + + assert "mutual_information" in engine.state.metrics + res = engine.state.metrics["mutual_information"] + + true_res = np_mutual_information(y_preds.cpu().numpy()) + + assert pytest.approx(true_res, rel=tol) == res + + _test("cpu") + if device.type != "xla": + _test(idist.device()) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test_distrib_nccl_gpu(distributed_context_single_node_nccl): + device = idist.device() + _test_distrib_integration(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo): + device = idist.device() + _test_distrib_integration(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") +@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") +def test_distrib_hvd(gloo_hvd_executor): + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") + nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() + + gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") 
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo): + device = idist.device() + _test_distrib_integration(device) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl): + device = idist.device() + _test_distrib_integration(device) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_single_device_xla(): + device = idist.device() + _test_distrib_integration(device, tol=1e-4) + + +def _test_distrib_xla_nprocs(index): + device = idist.device() + _test_distrib_integration(device, tol=1e-4) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_xla_nprocs(xmp_executor): + n = int(os.environ["NUM_TPU_WORKERS"]) + xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) From c6cf3e554a8f96f58fe83ffc597b4193f4122f08 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 24 Mar 2024 00:58:24 +0900 Subject: [PATCH 02/19] update test for MutualInformation metric --- .../ignite/metrics/test_mutual_information.py | 150 ++++++++---------- 1 file changed, 62 insertions(+), 88 deletions(-) diff --git a/tests/ignite/metrics/test_mutual_information.py b/tests/ignite/metrics/test_mutual_information.py index 8e2f2cd4766..a87dbb62db3 100644 --- a/tests/ignite/metrics/test_mutual_information.py +++ b/tests/ignite/metrics/test_mutual_information.py @@ -1,15 +1,16 @@ -from typing import Tuple import os +from typing import Tuple import numpy as np -from scipy.special import softmax -from scipy.stats import entropy import pytest import torch +from scipy.special import softmax +from scipy.stats import entropy from torch import Tensor -from ignite.engine import Engine import ignite.distributed as idist + +from ignite.engine import Engine from ignite.exceptions import NotComputableError from ignite.metrics import MutualInformation @@ -61,7 +62,7 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]): n_iters = y.shape[0] // batch_size + 1 for i in range(n_iters): idx = i * batch_size - mi.update((y_pred[idx: idx + batch_size], y[idx: idx + batch_size])) + mi.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) else: mi.update((y_pred, y)) @@ -72,104 +73,77 @@ def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]): assert pytest.approx(np_res, rel=1e-4) == res -def _test_distrib_integration(device, tol=1e-4): - rank = idist.get_rank() - torch.manual_seed(12 + rank) - - def _test(metric_device): - n_iters = 100 - batch_size = 10 - n_cls = 50 - - y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device) - y_preds = torch.normal(0.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device) - - def update(engine, i): - return ( - y_preds[i * batch_size: (i + 1) * batch_size], - y_true[i * batch_size: (i + 1) * batch_size], - ) - - engine = Engine(update) - - m = 
MutualInformation(device=metric_device) - m.attach(engine, "mutual_information") - - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) - - y_preds = idist.all_gather(y_preds) - y_true = idist.all_gather(y_true) - - assert "mutual_information" in engine.state.metrics - res = engine.state.metrics["mutual_information"] - - true_res = np_mutual_information(y_preds.cpu().numpy()) - - assert pytest.approx(true_res, rel=tol) == res +def test_accumulator_detached(): + mi = MutualInformation() - _test("cpu") - if device.type != "xla": - _test(idist.device()) + y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True) + y = torch.zeros(2) + mi.update((y_pred, y)) + assert all( + (not accumulator.requires_grad) for accumulator in (mi._sum_of_conditional_entropies, mi._sum_of_probabilities) + ) -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") -def test_distrib_nccl_gpu(distributed_context_single_node_nccl): - device = idist.device() - _test_distrib_integration(device) +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_integration(self): + tol = 1e-4 + device = idist.device() + rank = idist.get_rank() + torch.manual_seed(12 + rank) -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo): - device = idist.device() - _test_distrib_integration(device) + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + for metric_device in metric_devices: + n_iters = 100 + batch_size = 10 + n_cls = 50 -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") -@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") -def test_distrib_hvd(gloo_hvd_executor): - device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") - nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() + y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device) + y_preds = torch.normal(0.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device) - gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + engine = Engine( + lambda e, i: ( + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + m = MutualInformation(device=metric_device) + m.attach(engine, "mutual_information") -@pytest.mark.multinode_distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") -def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo): - device = idist.device() - _test_distrib_integration(device) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) -@pytest.mark.multinode_distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") -def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl): - device = 
idist.device() - _test_distrib_integration(device) + assert "mutual_information" in engine.state.metrics + res = engine.state.metrics["mutual_information"] + true_res = np_mutual_information(y_preds.cpu().numpy()) -@pytest.mark.tpu -@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") -@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") -def test_distrib_single_device_xla(): - device = idist.device() - _test_distrib_integration(device, tol=1e-4) + assert pytest.approx(true_res, rel=tol) == res + def test_accumulator_device(self): + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + for metric_device in metric_devices: + mi = MutualInformation(device=metric_device) -def _test_distrib_xla_nprocs(index): - device = idist.device() - _test_distrib_integration(device, tol=1e-4) + devices = (mi._device, mi._sum_of_conditional_entropies, mi._sum_of_probabilities) + for dev in devices: + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" + y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True) + y = torch.zeros(2) + mi.update((y_pred, y)) -@pytest.mark.tpu -@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") -@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") -def test_distrib_xla_nprocs(xmp_executor): - n = int(os.environ["NUM_TPU_WORKERS"]) - xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) + devices = (mi._device, mi._sum_of_conditional_entropies, mi._sum_of_probabilities) + for dev in devices: + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From 06b218cd525a2f75f064f9eebc7cde6025f20fa7 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 24 Mar 2024 00:58:42 +0900 Subject: [PATCH 03/19] format code for MutualInformation Metric --- ignite/metrics/mutual_information.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index 81926df7abe..52d93c14941 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -13,7 +13,7 @@ class MutualInformation(Metric): r"""Calculates the `mutual information `_ between input :math:`X` and prediction :math:`Y`. - .. math:: I(X;Y) = H(Y) - H(Y|X) + .. 
math:: I(X;Y) = H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right) - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), @@ -102,5 +102,5 @@ def compute(self) -> float: marginal_ent = -(marginal_prob * torch.log(marginal_prob)).sum() conditional_ent = self._sum_of_conditional_entropies / n mi = marginal_ent - conditional_ent - mi = torch.clamp(mi, min=0.0) # mutual information cannot be negative + mi = torch.clamp(mi, min=0.0) # mutual information cannot be negative return float(mi.item()) From b85fbd111d552b9b19474343033c15e5fb5cf70f Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 24 Mar 2024 11:32:06 +0900 Subject: [PATCH 04/19] update test for MutualInformation metric --- tests/ignite/metrics/test_mutual_information.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/ignite/metrics/test_mutual_information.py b/tests/ignite/metrics/test_mutual_information.py index a87dbb62db3..b3210b3b105 100644 --- a/tests/ignite/metrics/test_mutual_information.py +++ b/tests/ignite/metrics/test_mutual_information.py @@ -37,7 +37,7 @@ def test_invalid_shape(): mi.update((y_pred, None)) -@pytest.fixture(params=[item for item in range(4)]) +@pytest.fixture(params=list(range(4))) def test_case(request): return [ (torch.randn((100, 10)).float(), torch.randint(0, 10, size=[100]), 1), @@ -89,6 +89,9 @@ def test_accumulator_detached(): class TestDistributed: def test_integration(self): tol = 1e-4 + n_iters = 100 + batch_size = 10 + n_cls = 50 device = idist.device() rank = idist.get_rank() torch.manual_seed(12 + rank) @@ -98,10 +101,6 @@ def test_integration(self): metric_devices.append(device) for metric_device in metric_devices: - n_iters = 100 - batch_size = 10 - n_cls = 50 - y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device) y_preds = torch.normal(0.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device) From 40e9e67167bf78e8fd1ab805fb88511eae0ed5ce Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sat, 30 Mar 2024 21:38:46 +0900 Subject: [PATCH 05/19] update test --- tests/ignite/metrics/test_mutual_information.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/ignite/metrics/test_mutual_information.py b/tests/ignite/metrics/test_mutual_information.py index b3210b3b105..232874cbf2c 100644 --- a/tests/ignite/metrics/test_mutual_information.py +++ b/tests/ignite/metrics/test_mutual_information.py @@ -80,9 +80,8 @@ def test_accumulator_detached(): y = torch.zeros(2) mi.update((y_pred, y)) - assert all( - (not accumulator.requires_grad) for accumulator in (mi._sum_of_conditional_entropies, mi._sum_of_probabilities) - ) + accumulators = (mi._sum_of_conditional_entropies, mi._sum_of_probabilities) + assert all((not accumulator.requires_grad) for accumulator in accumulators) @pytest.mark.usefixtures("distributed") From f7d3a4169d2f6e3990ffade58e0add3006c859ed Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sat, 30 Mar 2024 22:07:56 +0900 Subject: [PATCH 06/19] update docstring --- ignite/metrics/mutual_information.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index 52d93c14941..419ba47fa53 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -22,6 +22,8 @@ class MutualInformation(Metric): where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for :math:`i`-th input, and :math:`H(\mathbf{p})` is the entropy of 
:math:`\mathbf{p}`. + Intuitively, this metric measures how well input data are clustered by classes in the feature space. + - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric. - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. From 72e16647dbc276f299f656042c273fcc74612852 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 31 Mar 2024 11:51:31 +0900 Subject: [PATCH 07/19] fix device compatibility --- ignite/metrics/mutual_information.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index 419ba47fa53..7b2be477d0b 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -90,8 +90,8 @@ def update(self, output: Sequence[torch.Tensor]) -> None: log_prob = F.log_softmax(y_pred, dim=1) ent_sum = -(prob * log_prob).sum() - self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0) - self._sum_of_conditional_entropies += ent_sum + self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device) + self._sum_of_conditional_entropies += ent_sum.to(self._device) self._num_examples += y_pred.shape[0] @sync_all_reduce("_sum_of_probabilities", "_sum_of_conditional_entropies", "_num_examples") From 3f5c28d465569f35f629707bb9e269c3c7a1499f Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 31 Mar 2024 11:57:43 +0900 Subject: [PATCH 08/19] fix test_accumulator_device for MutualInformation metric --- tests/ignite/metrics/test_mutual_information.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ignite/metrics/test_mutual_information.py b/tests/ignite/metrics/test_mutual_information.py index 232874cbf2c..5ca8d149d20 100644 --- a/tests/ignite/metrics/test_mutual_information.py +++ b/tests/ignite/metrics/test_mutual_information.py @@ -134,7 +134,7 @@ def test_accumulator_device(self): for metric_device in metric_devices: mi = MutualInformation(device=metric_device) - devices = (mi._device, mi._sum_of_conditional_entropies, mi._sum_of_probabilities) + devices = (mi._device, mi._sum_of_conditional_entropies.device, mi._sum_of_probabilities.device) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" @@ -142,6 +142,6 @@ def test_accumulator_device(self): y = torch.zeros(2) mi.update((y_pred, y)) - devices = (mi._device, mi._sum_of_conditional_entropies, mi._sum_of_probabilities) + devices = (mi._device, mi._sum_of_conditional_entropies.device, mi._sum_of_probabilities.device) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From 61669eeeae0e749c8ce04ce42b88a741138e4056 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 3 Apr 2024 23:17:48 +0900 Subject: [PATCH 09/19] update doc --- docs/source/metrics.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0696cc3070a..f6742f73be5 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -337,6 +337,7 @@ Complete list of metrics metric.Metric metrics_lambda.MetricsLambda MultiLabelConfusionMatrix + MutualInformation precision.Precision PSNR recall.Recall From 51676eb7543f5f2c14b7d311ef8c55b1fe364069 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Wed, 3 Apr 2024 23:28:32 +0900 Subject: [PATCH 
10/19] modify docstring --- ignite/metrics/mutual_information.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index 7b2be477d0b..b01c967dce1 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -11,9 +11,11 @@ class MutualInformation(Metric): r"""Calculates the `mutual information `_ - between input :math:`X` and prediction :math:`Y`. + between input :math:`X` and prediction :math:`Y`. - .. math:: I(X;Y) = H(Y) - H(Y|X) + .. math:: + + I(X;Y) = H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right) - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), From 6ededec52c3135ac25068ade51b6c196e153679b Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Thu, 4 Apr 2024 21:06:14 +0900 Subject: [PATCH 11/19] modify formula of docstring --- ignite/metrics/mutual_information.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index b01c967dce1..50e7d9b1b74 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -15,11 +15,10 @@ class MutualInformation(Metric): .. math:: - I(X;Y) = H(Y) - H(Y|X) + I(X;Y) &= H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right) - - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), - - H(\mathbf{p}) = -\sum_{c=1}^C p_c \log p_c. + - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\ + H(\mathbf{p}) &= -\sum_{c=1}^C p_c \log p_c. where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for :math:`i`-th input, and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`. From ba3e78f9ef3ddda6648fc70b051f36d3bbf91a1c Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Thu, 4 Apr 2024 21:17:26 +0900 Subject: [PATCH 12/19] update formula of docstring --- ignite/metrics/mutual_information.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index 50e7d9b1b74..c7377a9071c 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -14,7 +14,6 @@ class MutualInformation(Metric): between input :math:`X` and prediction :math:`Y`. .. math:: - I(X;Y) &= H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right) - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\ From 73b79288b4fb4adeefff3a49099501f446eaed86 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Thu, 4 Apr 2024 21:20:37 +0900 Subject: [PATCH 13/19] update formula of docstring --- ignite/metrics/mutual_information.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index c7377a9071c..faa5e62f752 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -14,10 +14,11 @@ class MutualInformation(Metric): between input :math:`X` and prediction :math:`Y`. .. math:: - I(X;Y) &= H(Y) - H(Y|X) - = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right) - - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\ - H(\mathbf{p}) &= -\sum_{c=1}^C p_c \log p_c. + \begin{align*} + I(X;Y) &= H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right) + - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\ + H(\mathbf{p}) &= -\sum_{c=1}^C p_c \log p_c. 
+ \end{align*} where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for :math:`i`-th input, and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`. From c4bc88a25c33386cb7a9b84ba71ca9b4c758e62e Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Fri, 5 Apr 2024 21:23:37 +0900 Subject: [PATCH 14/19] remove unused import --- tests/ignite/metrics/test_mutual_information.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ignite/metrics/test_mutual_information.py b/tests/ignite/metrics/test_mutual_information.py index 5ca8d149d20..f44e0c0cee3 100644 --- a/tests/ignite/metrics/test_mutual_information.py +++ b/tests/ignite/metrics/test_mutual_information.py @@ -1,4 +1,3 @@ -import os from typing import Tuple import numpy as np From d12475239b8661038264233d79030afc761a1d56 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Fri, 5 Apr 2024 21:29:31 +0900 Subject: [PATCH 15/19] add reference --- ignite/metrics/mutual_information.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index faa5e62f752..4f72986aa0e 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -23,7 +23,9 @@ class MutualInformation(Metric): where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for :math:`i`-th input, and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`. - Intuitively, this metric measures how well input data are clustered by classes in the feature space. + Intuitively, this metric measures how well input data are clustered by classes in the feature space [1]. + + [1] https://proceedings.mlr.press/v70/hu17b.html - ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric. - ``y_pred`` is expected to be the unnormalized logits for each class. 
:math:`(B, C)` (classification) From 366d85479e63b93831f930784b18ced8dfa5ef2b Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sat, 6 Apr 2024 11:38:28 +0900 Subject: [PATCH 16/19] commonalize redundant code --- ignite/metrics/entropy.py | 8 +++-- ignite/metrics/mutual_information.py | 35 +++++-------------- .../ignite/metrics/test_mutual_information.py | 7 ++-- 3 files changed, 18 insertions(+), 32 deletions(-) diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py index b3d0cff21b6..44241942c68 100644 --- a/ignite/metrics/entropy.py +++ b/ignite/metrics/entropy.py @@ -67,7 +67,6 @@ def reset(self) -> None: self._sum_of_entropies = torch.tensor(0.0, device=self._device) self._num_examples = 0 - @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: y_pred = output[0].detach() if y_pred.ndim >= 3: @@ -80,9 +79,14 @@ def update(self, output: Sequence[torch.Tensor]) -> None: prob = F.softmax(y_pred, dim=1) log_prob = F.log_softmax(y_pred, dim=1) + + self._update(prob, log_prob) + + @reinit__is_reduced + def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None: entropy_sum = -torch.sum(prob * log_prob) self._sum_of_entropies += entropy_sum.to(self._device) - self._num_examples += y_pred.shape[0] + self._num_examples += prob.shape[0] @sync_all_reduce("_sum_of_entropies", "_num_examples") def compute(self) -> float: diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index 4f72986aa0e..b784166db71 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -1,15 +1,13 @@ -from typing import Sequence - import torch -import torch.nn.functional as F from ignite.exceptions import NotComputableError -from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce +from ignite.metrics import Entropy +from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce __all__ = ["MutualInformation"] -class MutualInformation(Metric): +class MutualInformation(Entropy): r"""Calculates the `mutual information `_ between input :math:`X` and prediction :math:`Y`. @@ -70,34 +68,19 @@ class MutualInformation(Metric): 0.18599730730056763 """ - _state_dict_all_req_keys = ("_sum_of_probabilities", "_sum_of_conditional_entropies", "_num_examples") + _state_dict_all_req_keys = ("_sum_of_probabilities",) @reinit__is_reduced def reset(self) -> None: + super().reset() self._sum_of_probabilities = torch.tensor(0.0, device=self._device) - self._sum_of_conditional_entropies = torch.tensor(0.0, device=self._device) - self._num_examples = 0 @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred = output[0].detach() - if y_pred.ndim >= 3: - num_classes = y_pred.shape[1] - # (B, C, ...) -> (B, ..., C) -> (B*..., C) - # regarding as B*... 
predictions - y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes) - elif y_pred.ndim == 1: - raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.") - - prob = F.softmax(y_pred, dim=1) - log_prob = F.log_softmax(y_pred, dim=1) - ent_sum = -(prob * log_prob).sum() - + def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None: + super()._update(prob, log_prob) self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device) - self._sum_of_conditional_entropies += ent_sum.to(self._device) - self._num_examples += y_pred.shape[0] - @sync_all_reduce("_sum_of_probabilities", "_sum_of_conditional_entropies", "_num_examples") + @sync_all_reduce("_sum_of_probabilities") def compute(self) -> float: n = self._num_examples if n == 0: @@ -105,7 +88,7 @@ def compute(self) -> float: marginal_prob = self._sum_of_probabilities / n marginal_ent = -(marginal_prob * torch.log(marginal_prob)).sum() - conditional_ent = self._sum_of_conditional_entropies / n + conditional_ent = self._sum_of_entropies / n mi = marginal_ent - conditional_ent mi = torch.clamp(mi, min=0.0) # mutual information cannot be negative return float(mi.item()) diff --git a/tests/ignite/metrics/test_mutual_information.py b/tests/ignite/metrics/test_mutual_information.py index f44e0c0cee3..18d58d300bf 100644 --- a/tests/ignite/metrics/test_mutual_information.py +++ b/tests/ignite/metrics/test_mutual_information.py @@ -79,8 +79,7 @@ def test_accumulator_detached(): y = torch.zeros(2) mi.update((y_pred, y)) - accumulators = (mi._sum_of_conditional_entropies, mi._sum_of_probabilities) - assert all((not accumulator.requires_grad) for accumulator in accumulators) + assert not mi._sum_of_probabilities.requires_grad @pytest.mark.usefixtures("distributed") @@ -133,7 +132,7 @@ def test_accumulator_device(self): for metric_device in metric_devices: mi = MutualInformation(device=metric_device) - devices = (mi._device, mi._sum_of_conditional_entropies.device, mi._sum_of_probabilities.device) + devices = (mi._device, mi._sum_of_probabilities.device) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" @@ -141,6 +140,6 @@ def test_accumulator_device(self): y = torch.zeros(2) mi.update((y_pred, y)) - devices = (mi._device, mi._sum_of_conditional_entropies.device, mi._sum_of_probabilities.device) + devices = (mi._device, mi._sum_of_probabilities.device) for dev in devices: assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From d6b2ee6c63973089503688fd5669d4e533264ac4 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Mon, 8 Apr 2024 22:01:18 +0900 Subject: [PATCH 17/19] modify decorator --- ignite/metrics/entropy.py | 2 +- ignite/metrics/mutual_information.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ignite/metrics/entropy.py b/ignite/metrics/entropy.py index 44241942c68..4208bf205b3 100644 --- a/ignite/metrics/entropy.py +++ b/ignite/metrics/entropy.py @@ -67,6 +67,7 @@ def reset(self) -> None: self._sum_of_entropies = torch.tensor(0.0, device=self._device) self._num_examples = 0 + @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: y_pred = output[0].detach() if y_pred.ndim >= 3: @@ -82,7 +83,6 @@ def update(self, output: Sequence[torch.Tensor]) -> None: self._update(prob, log_prob) - @reinit__is_reduced def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None: entropy_sum = -torch.sum(prob * log_prob) 
self._sum_of_entropies += entropy_sum.to(self._device) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index b784166db71..58fe1b35841 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -75,7 +75,6 @@ def reset(self) -> None: super().reset() self._sum_of_probabilities = torch.tensor(0.0, device=self._device) - @reinit__is_reduced def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None: super()._update(prob, log_prob) self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device) From b1cc792722a4b23165939f41b9cf9595c2de0b09 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Mon, 8 Apr 2024 22:04:21 +0900 Subject: [PATCH 18/19] add a comment --- ignite/metrics/mutual_information.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index 58fe1b35841..e79c074fd69 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -77,6 +77,7 @@ def reset(self) -> None: def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None: super()._update(prob, log_prob) + # We can't use += below as _sum_of_probabilities can be a scalar and prob.sum(dim=0) is a vector self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device) @sync_all_reduce("_sum_of_probabilities") From a23b4352ce826dbbda4ca33413822c7b0dc30514 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Mon, 8 Apr 2024 23:34:54 +0900 Subject: [PATCH 19/19] fix decorator --- ignite/metrics/mutual_information.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/mutual_information.py b/ignite/metrics/mutual_information.py index e79c074fd69..2cca768ce43 100644 --- a/ignite/metrics/mutual_information.py +++ b/ignite/metrics/mutual_information.py @@ -80,7 +80,7 @@ def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None: # We can't use += below as _sum_of_probabilities can be a scalar and prob.sum(dim=0) is a vector self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device) - @sync_all_reduce("_sum_of_probabilities") + @sync_all_reduce("_sum_of_probabilities", "_sum_of_entropies", "_num_examples") def compute(self) -> float: n = self._num_examples if n == 0:
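
A quick way to see why the two running sums (plus a sample count) that remain after [PATCH 16/19] are sufficient: mutual information splits into a marginal-entropy term, computable from a running sum of probability vectors, and a mean conditional-entropy term, computable from a running sum of per-sample entropies, so batch-wise accumulation loses nothing. The sketch below is a standalone, torch-only illustration of that decomposition, not code from the patches; ``sum_prob`` and ``sum_ent`` are hypothetical stand-ins for the metric's ``_sum_of_probabilities`` and ``_sum_of_entropies``, and the shapes, seed, and batch size are arbitrary:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(100, 10)  # (B, C) unnormalized scores

    # Direct evaluation of the docstring formula on the full batch:
    # I(X;Y) = H( mean_i p_i ) - mean_i H(p_i)
    prob = F.softmax(logits, dim=1)
    log_prob = F.log_softmax(logits, dim=1)
    marginal = prob.mean(dim=0)
    mi_direct = -(marginal * marginal.log()).sum() + (prob * log_prob).sum(dim=1).mean()

    # Streaming evaluation with two running sums, mirroring update()/compute():
    sum_prob = torch.zeros(logits.shape[1])  # stand-in for _sum_of_probabilities
    sum_ent = torch.tensor(0.0)              # stand-in for _sum_of_entropies
    n = 0
    for batch in logits.split(16):  # any batching yields the same sums
        p = F.softmax(batch, dim=1)
        lp = F.log_softmax(batch, dim=1)
        sum_prob += p.sum(dim=0)        # running sum of probability vectors
        sum_ent += -(p * lp).sum()      # running sum of per-sample entropies
        n += batch.shape[0]

    marginal_stream = sum_prob / n
    mi_stream = -(marginal_stream * marginal_stream.log()).sum() - sum_ent / n

    assert torch.allclose(mi_direct, mi_stream, atol=1e-5)
    print(float(mi_stream))

Because softmax and the per-sample entropy are row-wise operations, the two sums are independent of how the data are batched, which is why ``compute()`` is exact rather than an approximation, and why, per [PATCH 19/19], reducing ``_sum_of_probabilities``, ``_sum_of_entropies`` and ``_num_examples`` across workers is all that distributed mode needs.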