Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Cleanup silhouette variables #31

Merged
merged 4 commits into from
Oct 12, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/scib_metrics/utils/_silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def _intra_cluster_distances_block(i: int, input: _IntraClusterData) -> jnp.ndar
return intra_dist


def _nearest_cluster_distances(X: np.ndarray, labels: np.ndarray):
def _nearest_cluster_distances(X: np.ndarray, labels: np.ndarray, unique_labels: np.ndarray):
"""Calculate the mean nearest-cluster distance for observation i."""
unique_labels = jnp.unique(labels)
inter_dist = jnp.array(np.inf * np.ones((X.shape[0],)))
label_combinations = jnp.array([list(i) for i in list(itertools.combinations(unique_labels, 2))])
for i in range(len(label_combinations)):
Expand All @@ -74,9 +73,9 @@ def _inter_values(subset_a: np.ndarray, subset_b: np.ndarray) -> Union[jnp.ndarr
return values_a, values_b


def _nearest_cluster_distance_block(inter_dist: jnp.ndarray, input: _InterClusterData) -> jnp.ndarray:
label_a = input.label_combos[inter_dist, 0]
label_b = input.label_combos[inter_dist, 1]
def _nearest_cluster_distance_block(combo_ind: int, input: _InterClusterData) -> jnp.ndarray:
label_a = input.label_combos[combo_ind, 0]
label_b = input.label_combos[combo_ind, 1]
label_mask_a = input.labels == label_a
label_mask_b = input.labels == label_b
subset_a = input.data[label_mask_a]
Expand Down Expand Up @@ -109,7 +108,6 @@ def silhouette_samples(X: np.ndarray, labels: np.ndarray) -> np.ndarray:
"""
if X.shape[0] != labels.shape[0]:
raise ValueError("X and labels should have the same number of samples")

intra_dist = _intra_cluster_distances(X, labels)
inter_dist = _nearest_cluster_distances(X, labels)
inter_dist = _nearest_cluster_distances(X, labels, np.unique(labels))
return jax.device_get((inter_dist - intra_dist) / jnp.maximum(intra_dist, inter_dist))