From 592d27d47eb9a76d87f1ad8147e82bda02ff1fb7 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 1 Sep 2024 21:36:38 +0900 Subject: [PATCH 01/12] add SpearmanRankCorrelation metric --- docs/source/metrics.rst | 1 + ignite/metrics/regression/__init__.py | 1 + .../regression/spearman_correlation.py | 96 +++++++++ .../regression/test_spearman_correlation.py | 191 ++++++++++++++++++ 4 files changed, 289 insertions(+) create mode 100644 ignite/metrics/regression/spearman_correlation.py create mode 100644 tests/ignite/metrics/regression/test_spearman_correlation.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0e4979f82a1..1d02f1b3d13 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -377,6 +377,7 @@ Complete list of metrics regression.MedianAbsolutePercentageError regression.MedianRelativeAbsoluteError regression.PearsonCorrelation + regression.SpearmanRankCorrelation regression.R2Score regression.WaveHedgesDistance diff --git a/ignite/metrics/regression/__init__.py b/ignite/metrics/regression/__init__.py index 7be1f18d0f3..0f4f58327e0 100644 --- a/ignite/metrics/regression/__init__.py +++ b/ignite/metrics/regression/__init__.py @@ -13,4 +13,5 @@ from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError from ignite.metrics.regression.pearson_correlation import PearsonCorrelation from ignite.metrics.regression.r2_score import R2Score +from ignite.metrics.regression.spearman_correlation import SpearmanRankCorrelation from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py new file mode 100644 index 00000000000..b4c3baee524 --- /dev/null +++ b/ignite/metrics/regression/spearman_correlation.py @@ -0,0 +1,96 @@ +from typing import Any, Callable, Tuple + +import torch + +from scipy.stats import spearmanr +from torch import Tensor + +from ignite.exceptions import NotComputableError +from ignite.metrics.epoch_metric import EpochMetric +from ignite.metrics.regression._base import _check_output_shapes, _check_output_types + + +def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float: + np_preds = predictions.flatten().numpy() + np_targets = targets.flatten().numpy() + r = spearmanr(np_preds, np_targets).statistic + return r + + +class SpearmanRankCorrelation(EpochMetric): + r"""Calculates the + `Spearman's rank correlation coefficient `_. + + .. math:: + r_\text{s} = \text{Corr}[R[P], R[A]] = \frac{\text{Cov}[R[P], R[A]]}{\sigma_{R[P]} \sigma_{R[A]}} + + where :math:`A` and :math:`P` are the ground truth and predicted value, and R[X] is the ranking value of X. + + The computation of this metric is implemented with + `scipy.stats.spearmanr `_. + + - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. + + Parameters are inherited from ``Metric.__init__``. + + Args: + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + device: specifies which device updates are accumulated on. 
Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in format of + ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = SpearmanRankCorrelation() + metric.attach(default_evaluator, 'spearman_corr') + y_true = torch.tensor([0., 1., 2., 3., 4., 5.]) + y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1]) + state = default_evaluator.run([[y_pred, y_true]]) + print(state.metrics['spearman_corr']) + + .. testoutput:: + + 0.7142857142857143 + """ + + def __init__( + self, + output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, + device: str | torch.device = torch.device("cpu"), + skip_unrolling: bool = False, + ) -> None: + super().__init__(_compute_spearman_r, output_transform, check_compute_fn, device, skip_unrolling) + + def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: + y_pred, y = output[0].detach(), output[1].detach() + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(1) + if y.ndim == 1: + y = y.unsqueeze(1) + + _check_output_shapes(output) + _check_output_types(output) + + super().update(output) + + def compute(self) -> float: + if len(self._predictions) < 1 or len(self._targets) < 1: + raise NotComputableError( + "SpearmanRankCorrelation must have at least one example before it can be computed." + ) + + return super().compute() diff --git a/tests/ignite/metrics/regression/test_spearman_correlation.py b/tests/ignite/metrics/regression/test_spearman_correlation.py new file mode 100644 index 00000000000..4aac6221f62 --- /dev/null +++ b/tests/ignite/metrics/regression/test_spearman_correlation.py @@ -0,0 +1,191 @@ +from typing import Tuple + +import numpy as np +import pytest + +import torch +from scipy.stats import spearmanr +from torch import Tensor + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.regression import SpearmanRankCorrelation + + +def test_zero_sample(): + with pytest.raises( + NotComputableError, match="SpearmanRankCorrelation must have at least one example before it can be computed" + ): + metric = SpearmanRankCorrelation() + metric.compute() + + +def test_wrong_y_pred_shape(): + with pytest.raises(ValueError, match=r"Input y_pred should have shape \(N,\) or \(N, 1\), but given"): + metric = SpearmanRankCorrelation() + y_pred = torch.arange(9).reshape(3, 3).float() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_shape(): + with pytest.raises(ValueError, match=r"Input y should have shape \(N,\) or \(N, 1\), but given"): + metric = SpearmanRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).float() + y = torch.arange(9).reshape(3, 3).float() + metric.update((y_pred, y)) + + +def test_wrong_y_pred_dtype(): + with pytest.raises(TypeError, match="Input y_pred dtype should be float 16, 32 or 64, but given"): + metric = SpearmanRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).long() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_dtype(): + with pytest.raises(TypeError, match="Input y dtype should be float 16, 32 or 64, but given"): + metric = SpearmanRankCorrelation() + y_pred = 
torch.arange(3).unsqueeze(1).float() + y = torch.arange(3).unsqueeze(1).long() + metric.update((y_pred, y)) + + +def test_spearman_correlation(): + a = np.random.randn(4).astype(np.float32) + b = np.random.randn(4).astype(np.float32) + c = np.random.randn(4).astype(np.float32) + d = np.random.randn(4).astype(np.float32) + ground_truth = np.random.randn(4).astype(np.float32) + + m = SpearmanRankCorrelation() + + m.update((torch.from_numpy(a), torch.from_numpy(ground_truth))) + np_ans = spearmanr(a, ground_truth).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(b), torch.from_numpy(ground_truth))) + np_ans = spearmanr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2)).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(c), torch.from_numpy(ground_truth))) + np_ans = spearmanr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3)).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(d), torch.from_numpy(ground_truth))) + np_ans = spearmanr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4)).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + +@pytest.fixture(params=list(range(2))) +def test_case(request): + # correlated sample + x = torch.randn(size=[50]).float() + y = x + torch.randn_like(x) * 0.1 + + return [ + (x, y, 1), + (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +def test_integration(n_times, test_case: Tuple[Tensor, Tensor, int]): + y_pred, y, batch_size = test_case + + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().ravel() + + def update_fn(engine: Engine, batch): + idx = (engine.state.iteration - 1) * batch_size + y_true_batch = np_y[idx : idx + batch_size] + y_pred_batch = np_y_pred[idx : idx + batch_size] + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) + + engine = Engine(update_fn) + + m = SpearmanRankCorrelation() + m.attach(engine, "spearman_corr") + + data = list(range(y_pred.shape[0] // batch_size)) + corr = engine.run(data, max_epochs=1).metrics["spearman_corr"] + + np_ans = spearmanr(np_y_pred, np_y).statistic + + assert pytest.approx(np_ans, rel=2e-4) == corr + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + def test_compute(self): + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + torch.manual_seed(10 + rank) + for metric_device in metric_devices: + m = SpearmanRankCorrelation(device=metric_device) + + y_pred = torch.rand(size=[100], device=device) + y = torch.rand(size=[100], device=device) + + m.update((y_pred, y)) + + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + + np_y = y.cpu().numpy() + np_y_pred = y_pred.cpu().numpy() + + np_ans = spearmanr(np_y_pred, np_y).statistic + + assert pytest.approx(np_ans, rel=2e-4) == m.compute() + + @pytest.mark.parametrize("n_epochs", [1, 2]) + def test_integration(self, n_epochs: int): + tol = 2e-4 + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + n_iters = 80 + batch_size = 16 + + for metric_device in metric_devices: + torch.manual_seed(12 + rank) + + y_true = torch.rand(size=(n_iters * batch_size,)).to(device) + y_preds = torch.rand(size=(n_iters * batch_size,)).to(device) + + 
engine = Engine( + lambda e, i: ( + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + + corr = SpearmanRankCorrelation(device=metric_device) + corr.attach(engine, "spearman_corr") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + + assert "spearman_corr" in engine.state.metrics + + res = engine.state.metrics["spearman_corr"] + + np_y = y_true.cpu().numpy() + np_y_pred = y_preds.cpu().numpy() + + np_ans = spearmanr(np_y_pred, np_y).statistic + + assert pytest.approx(np_ans, rel=tol) == res From a116a177ae5f609781169ec48699d33215446e4e Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 1 Sep 2024 21:37:09 +0900 Subject: [PATCH 02/12] add KendallRankCorrelation metric --- docs/source/metrics.rst | 1 + ignite/metrics/regression/__init__.py | 1 + .../metrics/regression/kendall_correlation.py | 109 ++++++++++ .../regression/test_kendall_correlation.py | 200 ++++++++++++++++++ 4 files changed, 311 insertions(+) create mode 100644 ignite/metrics/regression/kendall_correlation.py create mode 100644 tests/ignite/metrics/regression/test_kendall_correlation.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 1d02f1b3d13..d67c90c2248 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -378,6 +378,7 @@ Complete list of metrics regression.MedianRelativeAbsoluteError regression.PearsonCorrelation regression.SpearmanRankCorrelation + regression.KendallRankCorrelation regression.R2Score regression.WaveHedgesDistance diff --git a/ignite/metrics/regression/__init__.py b/ignite/metrics/regression/__init__.py index 0f4f58327e0..4be1abddb11 100644 --- a/ignite/metrics/regression/__init__.py +++ b/ignite/metrics/regression/__init__.py @@ -3,6 +3,7 @@ from ignite.metrics.regression.fractional_bias import FractionalBias from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError +from ignite.metrics.regression.kendall_correlation import KendallRankCorrelation from ignite.metrics.regression.manhattan_distance import ManhattanDistance from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py new file mode 100644 index 00000000000..5bb5cac7151 --- /dev/null +++ b/ignite/metrics/regression/kendall_correlation.py @@ -0,0 +1,109 @@ +from typing import Any, Callable, Tuple + +import torch + +from scipy.stats import kendalltau +from torch import Tensor + +from ignite.exceptions import NotComputableError +from ignite.metrics.epoch_metric import EpochMetric +from ignite.metrics.regression._base import _check_output_shapes, _check_output_types + + +def _compute_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]: + if variant not in ("b", "c"): + raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.") + + def _tau(predictions: Tensor, targets: Tensor) -> float: + np_preds = predictions.flatten().numpy() + np_targets = targets.flatten().numpy() + r = kendalltau(np_preds, np_targets, variant=variant).statistic + return r + + return _tau + + +class KendallRankCorrelation(EpochMetric): + r"""Calculates the + `Kendall rank 
correlation coefficient `_. + + .. math:: + \tau = 1-\frac{2(\text{number of discordant pairs})}{\left( \begin{array}{c}n\\2\end{array} \right)} + + Two prediction-target pairs :math:`(P_i, A_i)` and :math:`(P_j, A_j)`, where :math:`iP_j` and :math:`A_i>A_j`. + + The ``number of discordant pairs`` counts the number of pairs that are not concordant. + + The computation of this metric is implemented with + `scipy.stats.kendalltau `_. + + - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`. + + Parameters are inherited from ``Metric.__init__``. + + Args: + variant: variant of kendall rank correlation. ``b`` or ``c`` is accepted. + Details can be found + `here `_. + Default: ``b`` + output_transform: a callable that is used to transform the + :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the + form expected by the metric. This can be useful if, for example, you have a multi-output model and + you want to compute the metric with respect to one of the outputs. + By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + device: specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. + + Examples: + To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. + The output of the engine's ``process_function`` needs to be in format of + ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. + + .. include:: defaults.rst + :start-after: :orphan: + + .. testcode:: + + metric = KendallRankCorrelation() + metric.attach(default_evaluator, 'kendall_tau') + y_true = torch.tensor([0., 1., 2., 3., 4., 5.]) + y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1]) + state = default_evaluator.run([[y_pred, y_true]]) + print(state.metrics['kendall_tau']) + + .. 
testoutput:: + + 0.4666666666666666 + """ + + def __init__( + self, + variant: str = "b", + output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, + device: str | torch.device = torch.device("cpu"), + skip_unrolling: bool = False, + ) -> None: + super().__init__(_compute_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling) + + def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: + y_pred, y = output[0].detach(), output[1].detach() + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(1) + if y.ndim == 1: + y = y.unsqueeze(1) + + _check_output_shapes(output) + _check_output_types(output) + + super().update(output) + + def compute(self) -> float: + if len(self._predictions) < 1 or len(self._targets) < 1: + raise NotComputableError("KendallRankCorrelation must have at least one example before it can be computed.") + + return super().compute() diff --git a/tests/ignite/metrics/regression/test_kendall_correlation.py b/tests/ignite/metrics/regression/test_kendall_correlation.py new file mode 100644 index 00000000000..5dd55b0691b --- /dev/null +++ b/tests/ignite/metrics/regression/test_kendall_correlation.py @@ -0,0 +1,200 @@ +from typing import Tuple + +import numpy as np +import pytest + +import torch +from scipy.stats import kendalltau +from torch import Tensor + +from ignite import distributed as idist +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.regression import KendallRankCorrelation + + +def test_zero_sample(): + with pytest.raises( + NotComputableError, match="KendallRankCorrelation must have at least one example before it can be computed" + ): + metric = KendallRankCorrelation() + metric.compute() + + +def test_wrong_y_pred_shape(): + with pytest.raises(ValueError, match=r"Input y_pred should have shape \(N,\) or \(N, 1\), but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(9).reshape(3, 3).float() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_shape(): + with pytest.raises(ValueError, match=r"Input y should have shape \(N,\) or \(N, 1\), but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).float() + y = torch.arange(9).reshape(3, 3).float() + metric.update((y_pred, y)) + + +def test_wrong_y_pred_dtype(): + with pytest.raises(TypeError, match="Input y_pred dtype should be float 16, 32 or 64, but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).long() + y = torch.arange(3).unsqueeze(1).float() + metric.update((y_pred, y)) + + +def test_wrong_y_dtype(): + with pytest.raises(TypeError, match="Input y dtype should be float 16, 32 or 64, but given"): + metric = KendallRankCorrelation() + y_pred = torch.arange(3).unsqueeze(1).float() + y = torch.arange(3).unsqueeze(1).long() + metric.update((y_pred, y)) + + +def test_wrong_variant(): + with pytest.raises(ValueError, match="variant accepts 'b' or 'c', got"): + KendallRankCorrelation(variant="x") + + +@pytest.mark.parametrize("variant", ["b", "c"]) +def test_kendall_correlation(variant: str): + a = np.random.randn(4).astype(np.float32) + b = np.random.randn(4).astype(np.float32) + c = np.random.randn(4).astype(np.float32) + d = np.random.randn(4).astype(np.float32) + ground_truth = np.random.randn(4).astype(np.float32) + + m = KendallRankCorrelation(variant=variant) + + m.update((torch.from_numpy(a), torch.from_numpy(ground_truth))) + np_ans = kendalltau(a, ground_truth, 
variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(b), torch.from_numpy(ground_truth))) + np_ans = kendalltau(np.concatenate([a, b]), np.concatenate([ground_truth] * 2), variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(c), torch.from_numpy(ground_truth))) + np_ans = kendalltau(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3), variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + m.update((torch.from_numpy(d), torch.from_numpy(ground_truth))) + np_ans = kendalltau(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4), variant=variant).statistic + assert m.compute() == pytest.approx(np_ans, rel=1e-4) + + +@pytest.fixture(params=list(range(2))) +def test_case(request): + # correlated sample + x = torch.randn(size=[50]).float() + y = x + torch.randn_like(x) * 0.1 + + return [ + (x, y, 1), + (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10), + ][request.param] + + +@pytest.mark.parametrize("n_times", range(5)) +@pytest.mark.parametrize("variant", ["b", "c"]) +def test_integration(n_times: int, variant: str, test_case: Tuple[Tensor, Tensor, int]): + y_pred, y, batch_size = test_case + + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().ravel() + + def update_fn(engine: Engine, batch): + idx = (engine.state.iteration - 1) * batch_size + y_true_batch = np_y[idx : idx + batch_size] + y_pred_batch = np_y_pred[idx : idx + batch_size] + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) + + engine = Engine(update_fn) + + m = KendallRankCorrelation(variant=variant) + m.attach(engine, "kendall_tau") + + data = list(range(y_pred.shape[0] // batch_size)) + corr = engine.run(data, max_epochs=1).metrics["kendall_tau"] + + np_ans = kendalltau(np_y_pred, np_y, variant=variant).statistic + + assert pytest.approx(np_ans, rel=2e-4) == corr + + +@pytest.mark.usefixtures("distributed") +class TestDistributed: + @pytest.mark.parametrize("variant", ["b", "c"]) + def test_compute(self, variant: str): + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + torch.manual_seed(10 + rank) + for metric_device in metric_devices: + m = KendallRankCorrelation(device=metric_device, variant=variant) + + y_pred = torch.rand(size=[100], device=device) + y = torch.rand(size=[100], device=device) + + m.update((y_pred, y)) + + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + + np_y = y.cpu().numpy() + np_y_pred = y_pred.cpu().numpy() + + np_ans = kendalltau(np_y_pred, np_y, variant=variant).statistic + + assert pytest.approx(np_ans, rel=2e-4) == m.compute() + + @pytest.mark.parametrize("n_epochs", [1, 2]) + @pytest.mark.parametrize("variant", ["b", "c"]) + def test_integration(self, n_epochs: int, variant: str): + tol = 2e-4 + rank = idist.get_rank() + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + n_iters = 80 + batch_size = 16 + + for metric_device in metric_devices: + torch.manual_seed(12 + rank) + + y_true = torch.rand(size=(n_iters * batch_size,)).to(device) + y_preds = torch.rand(size=(n_iters * batch_size,)).to(device) + + engine = Engine( + lambda e, i: ( + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], + ) + ) + + corr = 
KendallRankCorrelation(variant=variant, device=metric_device) + corr.attach(engine, "kendall_tau") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + + assert "kendall_tau" in engine.state.metrics + + res = engine.state.metrics["kendall_tau"] + + np_y = y_true.cpu().numpy() + np_y_pred = y_preds.cpu().numpy() + + np_ans = kendalltau(np_y_pred, np_y, variant=variant).statistic + + assert pytest.approx(np_ans, rel=tol) == res From 572c83c28bfe149adee61be18c1b1c6c4c85c331 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sun, 1 Sep 2024 22:32:23 +0900 Subject: [PATCH 03/12] add import check of scipy --- .../metrics/regression/kendall_correlation.py | 12 +++++++--- .../regression/spearman_correlation.py | 23 +++++++++++++------ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index 5bb5cac7151..fdc1778b388 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -2,7 +2,6 @@ import torch -from scipy.stats import kendalltau from torch import Tensor from ignite.exceptions import NotComputableError @@ -10,7 +9,9 @@ from ignite.metrics.regression._base import _check_output_shapes, _check_output_types -def _compute_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]: +def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]: + from scipy.stats import kendalltau + if variant not in ("b", "c"): raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.") @@ -88,7 +89,12 @@ def __init__( device: str | torch.device = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: - super().__init__(_compute_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling) + try: + from scipy.stats import kendalltau # noqa: F401 + except ImportError: + raise ModuleNotFoundError("This module requires scipy to be installed.") + + super().__init__(_get_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling) def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index b4c3baee524..fb1c565e4ea 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -2,7 +2,6 @@ import torch -from scipy.stats import spearmanr from torch import Tensor from ignite.exceptions import NotComputableError @@ -10,11 +9,16 @@ from ignite.metrics.regression._base import _check_output_shapes, _check_output_types -def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float: - np_preds = predictions.flatten().numpy() - np_targets = targets.flatten().numpy() - r = spearmanr(np_preds, np_targets).statistic - return r +def _get_spearman_r() -> Callable[[Tensor, Tensor], float]: + from scipy.stats import spearmanr + + def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float: + np_preds = predictions.flatten().numpy() + np_targets = targets.flatten().numpy() + r = spearmanr(np_preds, np_targets).statistic + return r + + return _compute_spearman_r class SpearmanRankCorrelation(EpochMetric): @@ -73,7 +77,12 @@ def __init__( device: str | torch.device = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: - 
super().__init__(_compute_spearman_r, output_transform, check_compute_fn, device, skip_unrolling) + try: + from scipy.stats import spearmanr # noqa: F401 + except ImportError: + raise ModuleNotFoundError("This module requires scipy to be installed.") + + super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling) def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() From 2553385594bfe9e2c30893d258ea4f19dac88462 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 3 Sep 2024 20:42:43 +0900 Subject: [PATCH 04/12] fix type hints --- ignite/metrics/regression/kendall_correlation.py | 4 ++-- ignite/metrics/regression/spearman_correlation.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index fdc1778b388..f1dbad5fff8 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Tuple +from typing import Any, Callable, Tuple, Union import torch @@ -86,7 +86,7 @@ def __init__( variant: str = "b", output_transform: Callable[..., Any] = lambda x: x, check_compute_fn: bool = True, - device: str | torch.device = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: try: diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index fb1c565e4ea..84ca3dc8b12 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Tuple +from typing import Any, Callable, Tuple, Union import torch @@ -74,7 +74,7 @@ def __init__( self, output_transform: Callable[..., Any] = lambda x: x, check_compute_fn: bool = True, - device: str | torch.device = torch.device("cpu"), + device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: try: From febc2004d9b3def87a40cb9439f8072621e80207 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 3 Sep 2024 20:44:53 +0900 Subject: [PATCH 05/12] fix formatting error --- ignite/metrics/regression/spearman_correlation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index 84ca3dc8b12..b57cc63994b 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -23,7 +23,8 @@ def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float: class SpearmanRankCorrelation(EpochMetric): r"""Calculates the - `Spearman's rank correlation coefficient `_. + `Spearman's rank correlation coefficient + `_. .. 
math:: r_\text{s} = \text{Corr}[R[P], R[A]] = \frac{\text{Cov}[R[P], R[A]]}{\sigma_{R[P]} \sigma_{R[A]}} From 0232b9f154769c7f22ecbb13ab39ed48d22c9e7b Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 3 Sep 2024 20:53:20 +0900 Subject: [PATCH 06/12] minor modification to docstring --- ignite/metrics/regression/kendall_correlation.py | 2 +- ignite/metrics/regression/spearman_correlation.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index f1dbad5fff8..826ca350ac2 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -35,7 +35,7 @@ class KendallRankCorrelation(EpochMetric): are said to be concordant when both :math:`P_iP_j` and :math:`A_i>A_j`. - The ``number of discordant pairs`` counts the number of pairs that are not concordant. + The `number of discordant pairs` counts the number of pairs that are not concordant. The computation of this metric is implemented with `scipy.stats.kendalltau `_. diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index b57cc63994b..7c5d586b152 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -29,7 +29,8 @@ class SpearmanRankCorrelation(EpochMetric): .. math:: r_\text{s} = \text{Corr}[R[P], R[A]] = \frac{\text{Cov}[R[P], R[A]]}{\sigma_{R[P]} \sigma_{R[A]}} - where :math:`A` and :math:`P` are the ground truth and predicted value, and R[X] is the ranking value of X. + where :math:`A` and :math:`P` are the ground truth and predicted value, + and :math:`R[X]` is the ranking value of :math:`X`. The computation of this metric is implemented with `scipy.stats.spearmanr `_. From e8a40f382b140505160c2ab283ede2672a35623f Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 10 Sep 2024 19:04:06 +0900 Subject: [PATCH 07/12] add versionadded directive to docstring --- ignite/metrics/regression/kendall_correlation.py | 2 ++ ignite/metrics/regression/spearman_correlation.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index 826ca350ac2..306cc7394ab 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -79,6 +79,8 @@ class KendallRankCorrelation(EpochMetric): .. testoutput:: 0.4666666666666666 + + .. versionadded:: 0.5.2 """ def __init__( diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index 7c5d586b152..d35db94ec4a 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -70,6 +70,8 @@ class SpearmanRankCorrelation(EpochMetric): .. testoutput:: 0.7142857142857143 + + .. 
versionadded:: 0.5.2 """ def __init__( From c0a05a09037711a3e3d67229465ad5b6d4309fe3 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 10 Sep 2024 19:04:48 +0900 Subject: [PATCH 08/12] add description for skip_unrolling argument --- ignite/metrics/regression/kendall_correlation.py | 3 +++ ignite/metrics/regression/spearman_correlation.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index 306cc7394ab..4db5f4e22d6 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -58,6 +58,9 @@ class KendallRankCorrelation(EpochMetric): device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, ``output_transform`` can be used to handle this. Examples: To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index d35db94ec4a..f5cc3cfff7b 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -49,6 +49,9 @@ class SpearmanRankCorrelation(EpochMetric): device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, ``output_transform`` can be used to handle this. Examples: To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. 
From 642b3f63315421df60826bbabd6332f94538f60e Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 10 Sep 2024 19:06:00 +0900 Subject: [PATCH 09/12] remove check_compute_fn argument --- ignite/metrics/regression/kendall_correlation.py | 3 +-- ignite/metrics/regression/spearman_correlation.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index 4db5f4e22d6..0928c816b74 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -90,7 +90,6 @@ def __init__( self, variant: str = "b", output_transform: Callable[..., Any] = lambda x: x, - check_compute_fn: bool = True, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: @@ -99,7 +98,7 @@ def __init__( except ImportError: raise ModuleNotFoundError("This module requires scipy to be installed.") - super().__init__(_get_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling) + super().__init__(_get_kendall_tau(variant), output_transform, True, device, skip_unrolling) def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index f5cc3cfff7b..5dd6855b2d2 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -80,7 +80,6 @@ class SpearmanRankCorrelation(EpochMetric): def __init__( self, output_transform: Callable[..., Any] = lambda x: x, - check_compute_fn: bool = True, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: @@ -89,7 +88,7 @@ def __init__( except ImportError: raise ModuleNotFoundError("This module requires scipy to be installed.") - super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling) + super().__init__(_get_spearman_r(), output_transform, True, device, skip_unrolling) def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() From 1d3744bfc9a036eb67e4a85c2df583566837a79b Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 10 Sep 2024 19:06:29 +0900 Subject: [PATCH 10/12] minor update on docstring --- ignite/metrics/regression/kendall_correlation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index 0928c816b74..dc7dfa295b4 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -46,10 +46,10 @@ class KendallRankCorrelation(EpochMetric): Parameters are inherited from ``Metric.__init__``. Args: - variant: variant of kendall rank correlation. ``b`` or ``c`` is accepted. + variant: variant of kendall rank correlation. ``'b'`` or ``'c'`` is accepted. Details can be found `here `_. - Default: ``b`` + Default: ``'b'`` output_transform: a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. 
This can be useful if, for example, you have a multi-output model and From 465f55b6773f49b5af95a66c2633724690f5f409 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Tue, 10 Sep 2024 21:06:58 +0900 Subject: [PATCH 11/12] Revert "remove check_compute_fn argument" This reverts commit 642b3f63315421df60826bbabd6332f94538f60e. --- ignite/metrics/regression/kendall_correlation.py | 3 ++- ignite/metrics/regression/spearman_correlation.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index dc7dfa295b4..04090d0d7d9 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -90,6 +90,7 @@ def __init__( self, variant: str = "b", output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: @@ -98,7 +99,7 @@ def __init__( except ImportError: raise ModuleNotFoundError("This module requires scipy to be installed.") - super().__init__(_get_kendall_tau(variant), output_transform, True, device, skip_unrolling) + super().__init__(_get_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling) def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index 5dd6855b2d2..f5cc3cfff7b 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -80,6 +80,7 @@ class SpearmanRankCorrelation(EpochMetric): def __init__( self, output_transform: Callable[..., Any] = lambda x: x, + check_compute_fn: bool = True, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: @@ -88,7 +89,7 @@ def __init__( except ImportError: raise ModuleNotFoundError("This module requires scipy to be installed.") - super().__init__(_get_spearman_r(), output_transform, True, device, skip_unrolling) + super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling) def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach() From edf078f6cd1a024610f08a0444844b4a3c320748 Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Sat, 12 Oct 2024 12:29:44 +0900 Subject: [PATCH 12/12] add description for check_compute_fn argument --- ignite/metrics/regression/kendall_correlation.py | 3 +++ ignite/metrics/regression/spearman_correlation.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index 04090d0d7d9..7ad87b22402 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -55,6 +55,9 @@ class KendallRankCorrelation(EpochMetric): form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no + issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``. + Default, True. 
device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index f5cc3cfff7b..175198ed688 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -46,6 +46,9 @@ class SpearmanRankCorrelation(EpochMetric): form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. + check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no + issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``. + Default, True. device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU.
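
A minimal usage sketch (not part of the patches above) exercising both new metrics together once the series is applied and scipy is installed; the process function simply forwards each ``(y_pred, y)`` batch, and the tensor values mirror the docstring testcode examples:

import torch

from ignite.engine import Engine
from ignite.metrics.regression import KendallRankCorrelation, SpearmanRankCorrelation

y_true = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])

# The evaluator forwards (y_pred, y) batches unchanged; both EpochMetric subclasses
# accumulate predictions/targets over the epoch and call scipy (spearmanr / kendalltau)
# once in compute().
evaluator = Engine(lambda engine, batch: batch)

SpearmanRankCorrelation().attach(evaluator, "spearman_corr")
KendallRankCorrelation(variant="b").attach(evaluator, "kendall_tau")

state = evaluator.run([(y_pred, y_true)])
print(state.metrics["spearman_corr"])  # 0.7142857142857143, per the SpearmanRankCorrelation docstring
print(state.metrics["kendall_tau"])    # 0.4666666666666666, per the KendallRankCorrelation docstring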