update compute_fn style
kzkadc committed Oct 14, 2024
1 parent b15a9a9 commit a027027
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions ignite/metrics/clustering/silhouette_score.py
@@ -8,18 +8,6 @@
 __all__ = ["SilhouetteScore"]
 
 
-def _get_silhouette_score(**kwargs: Any) -> Callable[[Tensor, Tensor], float]:
-    from sklearn.metrics import silhouette_score
-
-    def _silhouette_score(features: Tensor, labels: Tensor) -> float:
-        np_features = features.numpy()
-        np_labels = labels.numpy()
-        score = silhouette_score(np_features, np_labels, **kwargs)
-        return score
-
-    return _silhouette_score
-
-
 class SilhouetteScore(_ClusteringMetricBase):
     r"""Calculates the
     `silhouette score <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_.
@@ -66,7 +54,7 @@ class SilhouetteScore(_ClusteringMetricBase):
         skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
             true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)``
             Alternatively, ``output_transform`` can be used to handle this.
-        **silhouette_kwargs: additional arguments passed to ``sklearn.metrics.silhouette_score``.
+        silhouette_kwargs: additional arguments passed to ``sklearn.metrics.silhouette_score``.
 
     Examples:
         To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
@@ -109,13 +97,21 @@ def __init__(
         check_compute_fn: bool = True,
         device: Union[str, torch.device] = torch.device("cpu"),
         skip_unrolling: bool = False,
-        **silhouette_kwargs: Any,
+        silhouette_kwargs: dict | None = None,
     ) -> None:
         try:
             from sklearn.metrics import silhouette_score  # noqa: F401
         except ImportError:
             raise ModuleNotFoundError("This module requires scikit-learn to be installed.")
 
-        super().__init__(
-            _get_silhouette_score(**silhouette_kwargs), output_transform, check_compute_fn, device, skip_unrolling
-        )
+        self._silhouette_kwargs = {} if silhouette_kwargs is None else silhouette_kwargs
+
+        super().__init__(self._silhouette_score, output_transform, check_compute_fn, device, skip_unrolling)
+
+    def _silhouette_score(self, features: Tensor, labels: Tensor) -> float:
+        from sklearn.metrics import silhouette_score
+
+        np_features = features.numpy()
+        np_labels = labels.numpy()
+        score = silhouette_score(np_features, np_labels, **self._silhouette_kwargs)
+        return score
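
For context, a minimal usage sketch of the API after this commit: ``silhouette_kwargs`` is now passed as a single dict rather than splatted keyword arguments. The import path and the engine wiring below are assumptions based on the file location ``ignite/metrics/clustering/silhouette_score.py``; ``metric="euclidean"`` is just one illustrative ``sklearn.metrics.silhouette_score`` option.

    import torch
    from ignite.engine import Engine
    from ignite.metrics.clustering import SilhouetteScore  # assumed export path

    def process_function(engine, batch):
        # Hypothetical step: return (features, cluster_labels) for the metric to accumulate.
        features, labels = batch
        return features, labels

    engine = Engine(process_function)

    # Old style (before this commit): SilhouetteScore(metric="euclidean")
    # New style (after this commit): pass the sklearn kwargs as one dict.
    metric = SilhouetteScore(silhouette_kwargs={"metric": "euclidean"})
    metric.attach(engine, "silhouette")

    features = torch.randn(64, 8)
    labels = torch.randint(0, 4, (64,))
    state = engine.run([(features, labels)], max_epochs=1)
    print(state.metrics["silhouette"])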
