From a027027f33c7c7f7d27697f4fbb3d1e57e8b46fb Mon Sep 17 00:00:00 2001 From: Kazuki Adachi Date: Mon, 14 Oct 2024 18:27:20 +0900 Subject: [PATCH] update compute_fn style --- ignite/metrics/clustering/silhouette_score.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/ignite/metrics/clustering/silhouette_score.py b/ignite/metrics/clustering/silhouette_score.py index d82d603e0ea..13ea8bc0898 100644 --- a/ignite/metrics/clustering/silhouette_score.py +++ b/ignite/metrics/clustering/silhouette_score.py @@ -8,18 +8,6 @@ __all__ = ["SilhouetteScore"] -def _get_silhouette_score(**kwargs: Any) -> Callable[[Tensor, Tensor], float]: - from sklearn.metrics import silhouette_score - - def _silhouette_score(features: Tensor, labels: Tensor) -> float: - np_features = features.numpy() - np_labels = labels.numpy() - score = silhouette_score(np_features, np_labels, **kwargs) - return score - - return _silhouette_score - - class SilhouetteScore(_ClusteringMetricBase): r"""Calculates the `silhouette score `_. @@ -66,7 +54,7 @@ class SilhouetteScore(_ClusteringMetricBase): skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` Alternatively, ``output_transform`` can be used to handle this. - **silhouette_kwargs: additional arguments passed to ``sklearn.metrics.silhouette_score``. + silhouette_kwargs: additional arguments passed to ``sklearn.metrics.silhouette_score``. Examples: To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. @@ -109,13 +97,21 @@ def __init__( check_compute_fn: bool = True, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling: bool = False, - **silhouette_kwargs: Any, + silhouette_kwargs: dict | None = None, ) -> None: try: from sklearn.metrics import silhouette_score # noqa: F401 except ImportError: raise ModuleNotFoundError("This module requires scikit-learn to be installed.") - super().__init__( - _get_silhouette_score(**silhouette_kwargs), output_transform, check_compute_fn, device, skip_unrolling - ) + self._silhouette_kwargs = {} if silhouette_kwargs is None else silhouette_kwargs + + super().__init__(self._silhouette_score, output_transform, check_compute_fn, device, skip_unrolling) + + def _silhouette_score(self, features: Tensor, labels: Tensor) -> float: + from sklearn.metrics import silhouette_score + + np_features = features.numpy() + np_labels = labels.numpy() + score = silhouette_score(np_features, np_labels, **self._silhouette_kwargs) + return score