Skip to content

Commit

Permalink
Remove TaskShardId
Browse files Browse the repository at this point in the history
  • Loading branch information
brilee committed Dec 26, 2023
1 parent bd40007 commit a9e5667
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 59 deletions.
14 changes: 7 additions & 7 deletions lilac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
)
from ..signals.concept_scorer import ConceptSignal
from ..source import Source, resolve_source
from ..tasks import TaskExecutionType, TaskShardId
from ..tasks import TaskExecutionType, TaskId
from .dataset_format import DatasetFormat

# Threshold for rejecting certain queries (e.g. group by) for columns with large cardinality.
Expand Down Expand Up @@ -426,7 +426,7 @@ def compute_signal(
limit: Optional[int] = None,
include_deleted: bool = False,
overwrite: bool = False,
task_shard_id: Optional[TaskShardId] = None,
task_id: Optional[TaskId] = None,
) -> None:
"""Compute a signal for a column.
Expand All @@ -437,7 +437,7 @@ def compute_signal(
limit: Limit the number of rows to compute the signal on.
include_deleted: Whether to include deleted rows in the computation.
overwrite: Whether to overwrite an existing signal computed at this path.
task_shard_id: The TaskManager `task_shard_id` for this process run. This is used to update
task_id: The TaskManager `task_id` for this process run. This is used to update
the progress of the task.
"""
pass
Expand Down Expand Up @@ -475,11 +475,11 @@ def compute_embedding(
limit: Optional[int] = None,
include_deleted: bool = False,
overwrite: bool = False,
task_shard_id: Optional[TaskShardId] = None,
task_id: Optional[TaskId] = None,
) -> None:
"""Compute an embedding for a given field path."""
signal = get_signal_by_type(embedding, TextEmbeddingSignal)()
self.compute_signal(signal, path, filters, limit, include_deleted, overwrite, task_shard_id)
self.compute_signal(signal, path, filters, limit, include_deleted, overwrite, task_id)

def compute_concept(
self,
Expand All @@ -491,7 +491,7 @@ def compute_concept(
limit: Optional[int] = None,
include_deleted: bool = False,
overwrite: bool = False,
task_shard_id: Optional[TaskShardId] = None,
task_id: Optional[TaskId] = None,
) -> None:
"""Compute concept scores for a given field path."""
signal = ConceptSignal(namespace=namespace, concept_name=concept_name, embedding=embedding)
Expand All @@ -502,7 +502,7 @@ def compute_concept(
limit,
include_deleted,
overwrite=overwrite,
task_shard_id=task_shard_id,
task_id=task_id,
)

@abc.abstractmethod
Expand Down
64 changes: 23 additions & 41 deletions lilac/data/dataset_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
from ..source import NoSource, SourceManifest
from ..tasks import (
TaskExecutionType,
TaskShardId,
TaskId,
get_progress_bar,
)
from ..utils import (
Expand Down Expand Up @@ -280,22 +280,6 @@ class DuckDBQueryParams(BaseModel):
The columns are not included in this interface for now, because there are too many ways
that we use selects - all columns, one column, aliasing, unwrapping/flattening, etc.
Functions that will eventually wrap DuckDBQueryParams:
Exporters: choose the columns, choose some filters. Has some unique logic for "include" tags.
to_pandas -> _get_selection
to_json -> _get_selection
to_parquet -> _get_selection
to_csv -> _get_selection
The megafunction. Needs to be sliced down featurewise to simplify.
select_rows -> ???
Compute on one column. No partial application. Has complicated caching logic.
compute_concept/embedding -> _compute_disk_cached -> _select_iterable_values
Compute on a whole row. Has complicated caching logic. Has complicated sharding logic.
map/transform -> _map_worker -> _compute_disk_cached -> _select_iterable_values
Searches and tag inclusions should be compiled down to the equivalent Filter operations.
Sharding limit/offsets should be handled by computing and setting the desired offsets/limits/sort.
"""
Expand Down Expand Up @@ -994,12 +978,10 @@ def compute_signal(
limit: Optional[int] = None,
include_deleted: bool = False,
overwrite: bool = False,
task_shard_id: Optional[TaskShardId] = None,
task_id: Optional[TaskId] = None,
) -> None:
if isinstance(signal, TextEmbeddingSignal):
return self.compute_embedding(
signal.name, path, overwrite=overwrite, task_shard_id=task_shard_id
)
return self.compute_embedding(signal.name, path, overwrite=overwrite, task_id=task_id)

input_path = normalize_path(path)

Expand Down Expand Up @@ -1037,10 +1019,8 @@ def compute_signal(
offset = self._get_cache_len(jsonl_cache_filepath)
estimated_len = self.count(query_params)

if task_shard_id is not None:
progress_bar = get_progress_bar(
offset=offset, estimated_len=estimated_len, task_id=task_shard_id[0]
)
if task_id is not None:
progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len, task_id=task_id)
else:
progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len)

Expand Down Expand Up @@ -1102,7 +1082,7 @@ def compute_embedding(
limit: Optional[int] = None,
include_deleted: bool = False,
overwrite: bool = False,
task_shard_id: Optional[TaskShardId] = None,
task_id: Optional[TaskId] = None,
) -> None:
input_path = normalize_path(path)
add_project_embedding_config(
Expand Down Expand Up @@ -1138,10 +1118,8 @@ def compute_embedding(
offset = self._get_cache_len(jsonl_cache_filepath)
estimated_len = self.count(query_params)

if task_shard_id is not None:
progress_bar = get_progress_bar(
offset=offset, estimated_len=estimated_len, task_id=task_shard_id[0]
)
if task_id is not None:
progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len, task_id=task_id)
else:
progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len)

Expand Down Expand Up @@ -2651,6 +2629,7 @@ def map(
execution_type: TaskExecutionType = 'threads',
embedding: Optional[str] = None,
schema: Optional[Field] = None,
task_id: Optional[TaskId] = None,
) -> Iterable[Item]:
is_tmp_output = output_path is None
manifest = self.manifest()
Expand Down Expand Up @@ -2717,12 +2696,6 @@ def map(
is_temporary=is_tmp_output,
)

output_col_desc_suffix = f' to "{output_path}"' if output_path else ''
progress_description = (
f'[{self.namespace}/{self.dataset_name}][{num_jobs} shards] map '
f'"{map_fn_name}"{output_col_desc_suffix}'
)

sort_by = normalize_path(sort_by) if sort_by else None
query_params = DuckDBQueryParams(
filters=filters,
Expand All @@ -2734,11 +2707,20 @@ def map(

offset = self._get_cache_len(jsonl_cache_filepath)
estimated_len = self.count(query_params)
progress_bar = get_progress_bar(
offset=offset,
task_description=progress_description,
estimated_len=estimated_len,
)
if task_id is not None:
progress_bar = get_progress_bar(task_id, offset=offset, estimated_len=estimated_len)
else:
output_col_desc_suffix = f' to "{output_path}"' if output_path else ''
progress_description = (
f'[{self.namespace}/{self.dataset_name}][{num_jobs} shards] map '
f'"{map_fn_name}"{output_col_desc_suffix}'
)

progress_bar = get_progress_bar(
offset=offset,
task_description=progress_description,
estimated_len=estimated_len,
)

_consume_iterator(
progress_bar(
Expand Down
8 changes: 4 additions & 4 deletions lilac/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .schema import ROWID, PathTuple
from .tasks import (
TaskExecutionType,
TaskShardId,
TaskId,
TaskType,
get_task_manager,
)
Expand Down Expand Up @@ -185,7 +185,7 @@ def load(
d.name,
s,
project_dir,
(task_id, 0),
task_id,
overwrite,
)
else:
Expand Down Expand Up @@ -216,7 +216,7 @@ def _compute_signal(
name: str,
signal_config: SignalConfig,
project_dir: Union[str, pathlib.Path],
task_shard_id: TaskShardId,
task_id: TaskId,
overwrite: bool = False,
) -> None:
# Turn off debug logging.
Expand All @@ -228,7 +228,7 @@ def _compute_signal(
signal=signal_config.signal,
path=signal_config.path,
overwrite=overwrite,
task_shard_id=task_shard_id,
task_id=task_id,
)

# Free up RAM.
Expand Down
2 changes: 1 addition & 1 deletion lilac/router_dataset_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def compute_signal(
options.leaf_path,
# Overwrite for text embeddings since we don't have UI to control deleting embeddings.
overwrite=isinstance(options.signal, TextEmbeddingSignal),
task_shard_id=(task_id, 0),
task_id=task_id,
)

return ComputeSignalResponse(task_id=task_id)
Expand Down
4 changes: 0 additions & 4 deletions lilac/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ def yield_items(self) -> Iterable[Item]:
This method is easier to use, and simply requires you to return an iterator of Python dicts.
Lilac will take your iterable of items and handle writing it to parquet. You will still have to
override source_schema.
Args:
task_shard_id: The TaskManager `task_shard_id` for this process run. This is used to update
the progress of the task.
"""
raise NotImplementedError

Expand Down
2 changes: 0 additions & 2 deletions lilac/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
from .env import env
from .utils import pretty_timedelta

# A tuple of the (task_id, shard_id).
TaskId = str
TaskShardId = tuple[TaskId, int]
TaskFn = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]


Expand Down

0 comments on commit a9e5667

Please sign in to comment.