Make OpenAI calls threaded with exponential backoff #1005

Merged · 2 commits · Dec 27, 2023

Changes from 1 commit
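The change pairs two libraries: tenacity retries each per-cluster OpenAI call with randomized exponential backoff, and joblib fans those calls out across a thread pool. A minimal sketch of the combined pattern, with `call_openai` and `prompts` as hypothetical stand-ins for the request that `topic_fn` issues:

```python
from joblib import Parallel, delayed
from tenacity import retry, stop_after_attempt, wait_random_exponential

# Retry transient failures (e.g. rate limits) with randomized exponential
# backoff, waiting between 0.5s and 20s and giving up after 10 attempts.
@retry(wait=wait_random_exponential(min=0.5, max=20), stop=stop_after_attempt(10))
def call_openai(prompt: str) -> str:
  # Hypothetical stand-in for the per-cluster OpenAI request.
  return f'topic for {prompt!r}'

prompts = ['cluster 0 docs', 'cluster 1 docs']

# Fan the calls out over a thread pool; threads suit network-bound work, and
# return_as='generator' yields results lazily rather than all at once.
parallel = Parallel(n_jobs=32, backend='threading', return_as='generator')
for topic in parallel(delayed(call_openai)(p) for p in prompts):
  print(topic)
```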
55 changes: 34 additions & 21 deletions lilac/data/clustering.py
@@ -3,9 +3,11 @@
 from typing import Any, Iterator, Optional

 import instructor
+from joblib import Parallel, delayed
 from pydantic import (
   BaseModel,
 )
+from tenacity import retry, stop_after_attempt, wait_random_exponential

 from ..batch_utils import group_by_sorted_key_iter
 from ..schema import (
@@ -24,6 +26,8 @@

 _SHORTEN_LEN = 400
 _TOP_K_CENTRAL_DOCS = 5
+_NUM_THREADS = 32
+
 TOPIC_FIELD_NAME = 'topic'
 CLUSTER_FIELD_NAME = 'cluster'

@@ -127,32 +131,41 @@ def _compute_clusters(span_vectors: Iterator[list[SpanVector]]) -> Iterator[Item]:
     overwrite=overwrite,
   )

+  @retry(wait=wait_random_exponential(min=0.5, max=20), stop=stop_after_attempt(10))
+  def _compute_group_topic(
+    group: list[Item], text_column: str, cluster_column: str
+  ) -> tuple[int, Optional[str]]:
+    docs: list[tuple[str, float]] = []
+    for item in group:
+      text = item[text_column]
+      if not text:
+        continue
+      cluster_id = item[cluster_column][CLUSTER_ID]
+      if cluster_id < 0:
+        continue
+      membership_prob = item[cluster_column][MEMBERSHIP_PROB] or 0
+      if membership_prob == 0:
+        continue
+      docs.append((text, membership_prob))
+
+    # Sort by membership score.
+    sorted_docs = sorted(docs, key=lambda x: x[1], reverse=True)
+    topic = topic_fn(sorted_docs) if sorted_docs else None
+    return len(group), topic
+
   def _compute_topics(
     text_column: str, cluster_column: str, items: Iterator[Item]
   ) -> Iterator[Item]:
     # items here are pre-sorted by cluster id so we can group neighboring items.
     groups = group_by_sorted_key_iter(items, lambda item: item[cluster_column][CLUSTER_ID])
-    for group in groups:
-      docs: list[tuple[str, float]] = []
-      for item in group:
-        text = item[text_column]
-        if not text:
-          continue
-        cluster_id = item[cluster_column][CLUSTER_ID]
-        if cluster_id < 0:
-          continue
-        membership_prob = item[cluster_column][MEMBERSHIP_PROB] or 0
-        if membership_prob == 0:
-          continue
-        docs.append((text, membership_prob))
-
-      # Sort by membership score.
-      sorted_docs = sorted(docs, key=lambda x: x[1], reverse=True)
-      topic = topic_fn(sorted_docs) if sorted_docs else None
-
-      # Yield a topic for each item in the group since the combined output needs to be the same
-      # length as the combined input.
-      for item in group:
+    parallel = Parallel(n_jobs=_NUM_THREADS, backend='threading', return_as='generator')
+    output_generator = parallel(
+      delayed(_compute_group_topic)(group, text_column, cluster_column) for group in groups
+    )
+    for group_size, topic in output_generator:
+      # Yield the same topic for each item in the group since the output needs to be the same
+      # length as the input.
+      for _ in range(group_size):
         yield topic

   # Now that we have the clusters, compute the topic for each cluster with another transform.
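One detail worth noting about `return_as='generator'`: joblib yields results in submission order even though the calls run concurrently, so the stream of `(group_size, topic)` pairs stays aligned with the grouped input. A small self-contained demo of that ordering guarantee (the names here are illustrative, not from the diff):

```python
from joblib import Parallel, delayed

def measure(group: list[int]) -> tuple[int, int]:
  # Stand-in for _compute_group_topic: returns the group size and a result.
  return len(group), sum(group)

groups = [[1, 2], [3], [4, 5, 6]]
parallel = Parallel(n_jobs=2, backend='threading', return_as='generator')
for size, total in parallel(delayed(measure)(g) for g in groups):
  print(size, total)  # 2 3, then 1 3, then 3 15, in submission order
```

The `backend='threading'` choice avoids process spawning and result serialization, which fits here because the per-group work is network-bound rather than CPU-bound.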