From a5d3464231f37b86fe4e5ae272fe14bbcb33b27d Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Mon, 14 Oct 2024 16:40:14 +0900 Subject: [PATCH] simplify imports of metric functions (#3292) --- ignite/metrics/cohen_kappa.py | 16 +++++----------- .../metrics/regression/spearman_correlation.py | 15 ++++++--------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/ignite/metrics/cohen_kappa.py b/ignite/metrics/cohen_kappa.py index 15cb0222c25..b4d54f2a744 100644 --- a/ignite/metrics/cohen_kappa.py +++ b/ignite/metrics/cohen_kappa.py @@ -71,23 +71,17 @@ def __init__( # initalize weights self.weights = weights - self.cohen_kappa_compute = self.get_cohen_kappa_fn() - super(CohenKappa, self).__init__( - self.cohen_kappa_compute, + self._cohen_kappa_score, output_transform=output_transform, check_compute_fn=check_compute_fn, device=device, skip_unrolling=skip_unrolling, ) - def get_cohen_kappa_fn(self) -> Callable[[torch.Tensor, torch.Tensor], float]: - """Return a function computing Cohen Kappa from scikit-learn.""" + def _cohen_kappa_score(self, y_targets: torch.Tensor, y_preds: torch.Tensor) -> float: from sklearn.metrics import cohen_kappa_score - def wrapper(y_targets: torch.Tensor, y_preds: torch.Tensor) -> float: - y_true = y_targets.cpu().numpy() - y_pred = y_preds.cpu().numpy() - return cohen_kappa_score(y_true, y_pred, weights=self.weights) - - return wrapper + y_true = y_targets.cpu().numpy() + y_pred = y_preds.cpu().numpy() + return cohen_kappa_score(y_true, y_pred, weights=self.weights) diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index 175198ed688..7f126d6e56b 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -9,16 +9,13 @@ from ignite.metrics.regression._base import _check_output_shapes, _check_output_types -def _get_spearman_r() -> Callable[[Tensor, Tensor], float]: +def _spearman_r(predictions: Tensor, targets: Tensor) -> float: from scipy.stats import spearmanr - def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float: - np_preds = predictions.flatten().numpy() - np_targets = targets.flatten().numpy() - r = spearmanr(np_preds, np_targets).statistic - return r - - return _compute_spearman_r + np_preds = predictions.flatten().numpy() + np_targets = targets.flatten().numpy() + r = spearmanr(np_preds, np_targets).statistic + return r class SpearmanRankCorrelation(EpochMetric): @@ -92,7 +89,7 @@ def __init__( except ImportError: raise ModuleNotFoundError("This module requires scipy to be installed.") - super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling) + super().__init__(_spearman_r, output_transform, check_compute_fn, device, skip_unrolling) def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: y_pred, y = output[0].detach(), output[1].detach()