From cd5d4ed939ecac85c369c07f7ad95b0f5ea305a4 Mon Sep 17 00:00:00 2001
From: Daniel Smilkov
Date: Wed, 27 Dec 2023 10:51:37 -0500
Subject: [PATCH] Parallelize cluster topic computation with joblib and add
 retries

---
 lilac/data/clustering.py | 55 +++++++++++++++++++++++++---------------
 1 file changed, 34 insertions(+), 21 deletions(-)

diff --git a/lilac/data/clustering.py b/lilac/data/clustering.py
index 7d71e65d3..fe8d93b21 100644
--- a/lilac/data/clustering.py
+++ b/lilac/data/clustering.py
@@ -3,9 +3,11 @@
 from typing import Any, Iterator, Optional
 
 import instructor
+from joblib import Parallel, delayed
 from pydantic import (
   BaseModel,
 )
+from tenacity import retry, stop_after_attempt, wait_random_exponential
 
 from ..batch_utils import group_by_sorted_key_iter
 from ..schema import (
@@ -24,6 +26,8 @@
 _SHORTEN_LEN = 400
 _TOP_K_CENTRAL_DOCS = 5
 
+_NUM_THREADS = 32
+
 TOPIC_FIELD_NAME = 'topic'
 CLUSTER_FIELD_NAME = 'cluster'
 
@@ -127,32 +131,41 @@ def _compute_clusters(span_vectors: Iterator[list[SpanVector]]) -> Iterator[Item
       overwrite=overwrite,
     )
 
+  @retry(wait=wait_random_exponential(min=0.5, max=20), stop=stop_after_attempt(10))
+  def _compute_group_topic(
+    group: list[Item], text_column: str, cluster_column: str
+  ) -> tuple[int, Optional[str]]:
+    docs: list[tuple[str, float]] = []
+    for item in group:
+      text = item[text_column]
+      if not text:
+        continue
+      cluster_id = item[cluster_column][CLUSTER_ID]
+      if cluster_id < 0:
+        continue
+      membership_prob = item[cluster_column][MEMBERSHIP_PROB] or 0
+      if membership_prob == 0:
+        continue
+      docs.append((text, membership_prob))
+
+    # Sort by membership score.
+    sorted_docs = sorted(docs, key=lambda x: x[1], reverse=True)
+    topic = topic_fn(sorted_docs) if sorted_docs else None
+    return len(group), topic
+
   def _compute_topics(
     text_column: str, cluster_column: str, items: Iterator[Item]
   ) -> Iterator[Item]:
     # items here are pre-sorted by cluster id so we can group neighboring items.
     groups = group_by_sorted_key_iter(items, lambda item: item[cluster_column][CLUSTER_ID])
-    for group in groups:
-      docs: list[tuple[str, float]] = []
-      for item in group:
-        text = item[text_column]
-        if not text:
-          continue
-        cluster_id = item[cluster_column][CLUSTER_ID]
-        if cluster_id < 0:
-          continue
-        membership_prob = item[cluster_column][MEMBERSHIP_PROB] or 0
-        if membership_prob == 0:
-          continue
-        docs.append((text, membership_prob))
-
-      # Sort by membership score.
-      sorted_docs = sorted(docs, key=lambda x: x[1], reverse=True)
-      topic = topic_fn(sorted_docs) if sorted_docs else None
-
-      # Yield a topic for each item in the group since the combined output needs to be the same
-      # length as the combined input.
-      for item in group:
+    parallel = Parallel(n_jobs=_NUM_THREADS, backend='threading', return_as='generator')
+    output_generator = parallel(
+      delayed(_compute_group_topic)(group, text_column, cluster_column) for group in groups
+    )
+    for group_size, topic in output_generator:
+      # Yield the same topic for each item in the group since the output needs to be the same
+      # length as the input.
+      for _ in range(group_size):
         yield topic
 
   # Now that we have the clusters, compute the topic for each cluster with another transform.
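
A note on the two libraries this patch introduces. Parallel with return_as='generator'
(available in joblib >= 1.3) yields results lazily and in dispatch order, so the yielded
topics stay aligned with the sorted groups while memory stays bounded; the threading
backend fits here because each task is an I/O-bound call, and the tenacity decorator
retries transient failures with jittered exponential backoff. Below is a minimal
standalone sketch of the combined pattern, where flaky_topic is a hypothetical stand-in
for the patch's topic_fn, not code from this repository:

    from joblib import Parallel, delayed
    from tenacity import retry, stop_after_attempt, wait_random_exponential

    @retry(wait=wait_random_exponential(min=0.5, max=20), stop=stop_after_attempt(3))
    def flaky_topic(doc: str) -> str:
      # Stand-in for an expensive, occasionally failing call (e.g. an LLM request).
      return doc.upper()

    # Threads rather than processes: the work is I/O-bound, so the GIL is not a bottleneck.
    parallel = Parallel(n_jobs=4, backend='threading', return_as='generator')
    results = parallel(delayed(flaky_topic)(d) for d in ['news', 'sports', 'code'])
    for topic in results:  # Streams back lazily, in the order tasks were dispatched.
      print(topic)

Keeping dispatch order matters for the patch above: _compute_topics must emit exactly one
topic per input item, in input order, which the ordered generator output preserves without
buffering every group's result first.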