From a9e5667098b6e551344bc1610462e00b5b36406c Mon Sep 17 00:00:00 2001
From: Brian Lee
Date: Tue, 26 Dec 2023 16:35:47 -0500
Subject: [PATCH] Remove TaskShardId

---
 lilac/data/dataset.py           | 14 ++++----
 lilac/data/dataset_duckdb.py    | 64 ++++++++++++---------------------
 lilac/load.py                   |  8 ++---
 lilac/router_dataset_signals.py |  2 +-
 lilac/source.py                 |  4 ---
 lilac/tasks.py                  |  2 --
 6 files changed, 35 insertions(+), 59 deletions(-)

diff --git a/lilac/data/dataset.py b/lilac/data/dataset.py
index d71156386..ea42cca38 100644
--- a/lilac/data/dataset.py
+++ b/lilac/data/dataset.py
@@ -61,7 +61,7 @@
 )
 from ..signals.concept_scorer import ConceptSignal
 from ..source import Source, resolve_source
-from ..tasks import TaskExecutionType, TaskShardId
+from ..tasks import TaskExecutionType, TaskId
 from .dataset_format import DatasetFormat
 
 # Threshold for rejecting certain queries (e.g. group by) for columns with large cardinality.
@@ -426,7 +426,7 @@ def compute_signal(
     limit: Optional[int] = None,
     include_deleted: bool = False,
     overwrite: bool = False,
-    task_shard_id: Optional[TaskShardId] = None,
+    task_id: Optional[TaskId] = None,
   ) -> None:
     """Compute a signal for a column.
 
@@ -437,7 +437,7 @@ def compute_signal(
       limit: Limit the number of rows to compute the signal on.
       include_deleted: Whether to include deleted rows in the computation.
       overwrite: Whether to overwrite an existing signal computed at this path.
-      task_shard_id: The TaskManager `task_shard_id` for this process run. This is used to update
+      task_id: The TaskManager `task_id` for this process run. This is used to update
         the progress of the task.
     """
     pass
@@ -475,11 +475,11 @@ def compute_embedding(
     limit: Optional[int] = None,
     include_deleted: bool = False,
     overwrite: bool = False,
-    task_shard_id: Optional[TaskShardId] = None,
+    task_id: Optional[TaskId] = None,
   ) -> None:
     """Compute an embedding for a given field path."""
     signal = get_signal_by_type(embedding, TextEmbeddingSignal)()
-    self.compute_signal(signal, path, filters, limit, include_deleted, overwrite, task_shard_id)
+    self.compute_signal(signal, path, filters, limit, include_deleted, overwrite, task_id)
 
   def compute_concept(
     self,
@@ -491,7 +491,7 @@
     limit: Optional[int] = None,
     include_deleted: bool = False,
     overwrite: bool = False,
-    task_shard_id: Optional[TaskShardId] = None,
+    task_id: Optional[TaskId] = None,
   ) -> None:
     """Compute concept scores for a given field path."""
     signal = ConceptSignal(namespace=namespace, concept_name=concept_name, embedding=embedding)
@@ -502,7 +502,7 @@
       limit,
       include_deleted,
       overwrite=overwrite,
-      task_shard_id=task_shard_id,
+      task_id=task_id,
     )
 
   @abc.abstractmethod
diff --git a/lilac/data/dataset_duckdb.py b/lilac/data/dataset_duckdb.py
index d2d3acf17..c48c562b2 100644
--- a/lilac/data/dataset_duckdb.py
+++ b/lilac/data/dataset_duckdb.py
@@ -100,7 +100,7 @@
 from ..source import NoSource, SourceManifest
 from ..tasks import (
   TaskExecutionType,
-  TaskShardId,
+  TaskId,
   get_progress_bar,
 )
 from ..utils import (
@@ -280,22 +280,6 @@ class DuckDBQueryParams(BaseModel):
   The columns are not included in this interface for now, because there are too many ways that
   we use selects - all columns, one column, aliasing, unwrapping/flattening, etc.
 
-  Functions that will eventually wrap DuckDBQueryParams:
-  Exporters: choose the columns, choose some filters. Has some unique logic for "include" tags.
-    to_pandas -> _get_selection
-    to_json -> _get_selection
-    to_parquet -> _get_selection
-    to_csv -> _get_selection
-
-  The megafunction. Needs to be sliced down featurewise to simplify.
-    select_rows -> ???
-
-  Compute on one column. No partial application. Has complicated caching logic.
-    compute_concept/embedding -> _compute_disk_cached -> _select_iterable_values
-
-  Compute on a whole row. Has complicated caching logic. Has complicated sharding logic.
-    map/transform -> _map_worker -> _compute_disk_cached -> _select_iterable_values
-
   Searches and tag inclusions should be compiled down to the equivalent Filter operations.
   Sharding limit/offsets should be handled by computing and setting the desired offsets/limits/sort.
   """
@@ -994,12 +978,10 @@ def compute_signal(
     limit: Optional[int] = None,
     include_deleted: bool = False,
     overwrite: bool = False,
-    task_shard_id: Optional[TaskShardId] = None,
+    task_id: Optional[TaskId] = None,
   ) -> None:
     if isinstance(signal, TextEmbeddingSignal):
-      return self.compute_embedding(
-        signal.name, path, overwrite=overwrite, task_shard_id=task_shard_id
-      )
+      return self.compute_embedding(signal.name, path, overwrite=overwrite, task_id=task_id)
 
     input_path = normalize_path(path)
 
@@ -1037,10 +1019,8 @@ def compute_signal(
     offset = self._get_cache_len(jsonl_cache_filepath)
     estimated_len = self.count(query_params)
 
-    if task_shard_id is not None:
-      progress_bar = get_progress_bar(
-        offset=offset, estimated_len=estimated_len, task_id=task_shard_id[0]
-      )
+    if task_id is not None:
+      progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len, task_id=task_id)
     else:
       progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len)
 
@@ -1102,7 +1082,7 @@ def compute_embedding(
     limit: Optional[int] = None,
     include_deleted: bool = False,
     overwrite: bool = False,
-    task_shard_id: Optional[TaskShardId] = None,
+    task_id: Optional[TaskId] = None,
   ) -> None:
     input_path = normalize_path(path)
     add_project_embedding_config(
@@ -1138,10 +1118,8 @@ def compute_embedding(
     offset = self._get_cache_len(jsonl_cache_filepath)
     estimated_len = self.count(query_params)
 
-    if task_shard_id is not None:
-      progress_bar = get_progress_bar(
-        offset=offset, estimated_len=estimated_len, task_id=task_shard_id[0]
-      )
+    if task_id is not None:
+      progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len, task_id=task_id)
     else:
       progress_bar = get_progress_bar(offset=offset, estimated_len=estimated_len)
 
@@ -2651,6 +2629,7 @@ def map(
     execution_type: TaskExecutionType = 'threads',
     embedding: Optional[str] = None,
     schema: Optional[Field] = None,
+    task_id: Optional[TaskId] = None,
   ) -> Iterable[Item]:
     is_tmp_output = output_path is None
     manifest = self.manifest()
@@ -2717,12 +2696,6 @@ def map(
       is_temporary=is_tmp_output,
     )
 
-    output_col_desc_suffix = f' to "{output_path}"' if output_path else ''
-    progress_description = (
-      f'[{self.namespace}/{self.dataset_name}][{num_jobs} shards] map '
-      f'"{map_fn_name}"{output_col_desc_suffix}'
-    )
-
     sort_by = normalize_path(sort_by) if sort_by else None
     query_params = DuckDBQueryParams(
       filters=filters,
@@ -2734,11 +2707,20 @@ def map(
     offset = self._get_cache_len(jsonl_cache_filepath)
     estimated_len = self.count(query_params)
 
-    progress_bar = get_progress_bar(
-      offset=offset,
-      task_description=progress_description,
-      estimated_len=estimated_len,
-    )
+    if task_id is not None:
+      progress_bar = get_progress_bar(task_id, offset=offset, estimated_len=estimated_len)
+    else:
+      output_col_desc_suffix = f' to "{output_path}"' if output_path else ''
+      progress_description = (
+        f'[{self.namespace}/{self.dataset_name}][{num_jobs} shards] map '
+        f'"{map_fn_name}"{output_col_desc_suffix}'
+      )
+
+      progress_bar = get_progress_bar(
+        offset=offset,
+        task_description=progress_description,
+        estimated_len=estimated_len,
+      )
 
     _consume_iterator(
       progress_bar(
diff --git a/lilac/load.py b/lilac/load.py
index d34cbc80f..ebca6e6e9 100644
--- a/lilac/load.py
+++ b/lilac/load.py
@@ -23,7 +23,7 @@
 from .schema import ROWID, PathTuple
 from .tasks import (
   TaskExecutionType,
-  TaskShardId,
+  TaskId,
   TaskType,
   get_task_manager,
 )
@@ -185,7 +185,7 @@ def load(
           d.name,
           s,
           project_dir,
-          (task_id, 0),
+          task_id,
           overwrite,
         )
       else:
@@ -216,7 +216,7 @@ def _compute_signal(
   name: str,
   signal_config: SignalConfig,
   project_dir: Union[str, pathlib.Path],
-  task_shard_id: TaskShardId,
+  task_id: TaskId,
   overwrite: bool = False,
 ) -> None:
   # Turn off debug logging.
@@ -228,7 +228,7 @@ def _compute_signal(
     signal=signal_config.signal,
     path=signal_config.path,
     overwrite=overwrite,
-    task_shard_id=task_shard_id,
+    task_id=task_id,
   )
 
   # Free up RAM.
diff --git a/lilac/router_dataset_signals.py b/lilac/router_dataset_signals.py
index 4ac03f1de..1b212e2c2 100644
--- a/lilac/router_dataset_signals.py
+++ b/lilac/router_dataset_signals.py
@@ -64,7 +64,7 @@ def compute_signal(
     options.leaf_path,
     # Overwrite for text embeddings since we don't have UI to control deleting embeddings.
     overwrite=isinstance(options.signal, TextEmbeddingSignal),
-    task_shard_id=(task_id, 0),
+    task_id=task_id,
   )
 
   return ComputeSignalResponse(task_id=task_id)
diff --git a/lilac/source.py b/lilac/source.py
index b4f0c9e69..219fc3d9f 100644
--- a/lilac/source.py
+++ b/lilac/source.py
@@ -99,10 +99,6 @@ def yield_items(self) -> Iterable[Item]:
     This method is easier to use, and simply requires you to return an iterator of Python
     dicts. Lilac will take your iterable of items and handle writing it to parquet. You will
     still have to override source_schema.
-
-    Args:
-      task_shard_id: The TaskManager `task_shard_id` for this process run. This is used to update
-        the progress of the task.
     """
     raise NotImplementedError
 
diff --git a/lilac/tasks.py b/lilac/tasks.py
index 1011392fc..fa3d57f21 100644
--- a/lilac/tasks.py
+++ b/lilac/tasks.py
@@ -22,9 +22,7 @@
 from .env import env
 from .utils import pretty_timedelta
 
-# A tuple of the (task_id, shard_id).
 TaskId = str
-TaskShardId = tuple[TaskId, int]
 
 TaskFn = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]