Skip to content

Commit

Permalink
Cleanup silhouette variables (#31)
Browse files (browse the repository at this point in the history)
* cleanup sil

* cleanup sil

* cleanup sil

* Update src/scib_metrics/utils/_silhouette.py
  • Loading branch information
adamgayoso authored Oct 12, 2022
1 parent 9f63380 commit e189453
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/scib_metrics/utils/_silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def _intra_cluster_distances_block(i: int, input: _IntraClusterData) -> jnp.ndar
return intra_dist


def _nearest_cluster_distances(X: np.ndarray, labels: np.ndarray):
def _nearest_cluster_distances(X: np.ndarray, labels: np.ndarray, unique_labels: np.ndarray):
"""Calculate the mean nearest-cluster distance for observation i."""
unique_labels = jnp.unique(labels)
inter_dist = jnp.array(np.inf * np.ones((X.shape[0],)))
label_combinations = jnp.array([list(i) for i in list(itertools.combinations(unique_labels, 2))])
for i in range(len(label_combinations)):
Expand All @@ -74,9 +73,9 @@ def _inter_values(subset_a: np.ndarray, subset_b: np.ndarray) -> Union[jnp.ndarr
return values_a, values_b


def _nearest_cluster_distance_block(inter_dist: jnp.ndarray, input: _InterClusterData) -> jnp.ndarray:
label_a = input.label_combos[inter_dist, 0]
label_b = input.label_combos[inter_dist, 1]
def _nearest_cluster_distance_block(combo_ind: int, input: _InterClusterData) -> jnp.ndarray:
label_a = input.label_combos[combo_ind, 0]
label_b = input.label_combos[combo_ind, 1]
label_mask_a = input.labels == label_a
label_mask_b = input.labels == label_b
subset_a = input.data[label_mask_a]
Expand Down Expand Up @@ -109,7 +108,6 @@ def silhouette_samples(X: np.ndarray, labels: np.ndarray) -> np.ndarray:
"""
if X.shape[0] != labels.shape[0]:
raise ValueError("X and labels should have the same number of samples")

intra_dist = _intra_cluster_distances(X, labels)
inter_dist = _nearest_cluster_distances(X, labels)
inter_dist = _nearest_cluster_distances(X, labels, np.unique(labels))
return jax.device_get((inter_dist - intra_dist) / jnp.maximum(intra_dist, inter_dist))

0 comments on commit e189453

Please sign in to comment.