From 534154006cc4c4ab909fc490bd2adf80b4f6319a Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Tue, 4 Mar 2025 00:12:06 -0600 Subject: [PATCH 1/9] Refactor Dataset.map to reuse cache files mapped with different num_proc Fixes #7433 This refactor unifies num_proc is None or num_proc == 1 and num_proc > 1; instead of handling them completely separately where one uses a list of kwargs and shards and the other just uses a single set of kwargs and self, by wrapping the num_proc == 1 case in a list and making the difference just whether or not you use a pool, you set up either case to be able to load each other cache_files just by changing num_shards; num_proc == 1 can sequentially load the shards of a dataset mapped num_shards > 1 and sequentially map any missing shards Other than the structural refactor, the main contribution of this PR is get_existing_cache_file_map, which uses a regex of cache_file_name and suffix_template to find existing cache files, grouped by their num_shards; using this data structure, we can reset num_shards to an existing set of cache files, and load them accordingly --- src/datasets/arrow_dataset.py | 292 +++++++++++++++++++++------------- tests/test_arrow_dataset.py | 85 ++++++++++ 2 files changed, 270 insertions(+), 107 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 6c8a48f7757..1af65ba2e20 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -19,6 +19,7 @@ import contextlib import copy import fnmatch +import glob import inspect import itertools import json @@ -27,12 +28,13 @@ import posixpath import re import shutil +import string import sys import tempfile import time import warnings import weakref -from collections import Counter +from collections import Counter, defaultdict from collections.abc import Mapping from copy import deepcopy from functools import partial, wraps @@ -2964,6 +2966,11 @@ def map( if num_proc is not None and num_proc <= 0: raise ValueError("num_proc must be an integer > 0.") + string_formatter = string.Formatter() + fields = {field_name for _, field_name, _, _ in string_formatter.parse(suffix_template) if field_name} + if fields != {"rank", "num_proc"}: + raise ValueError(f"suffix_template must contain exactly the fields 'rank' and 'num_proc', got: {fields}") + # If the array is empty we do nothing (but we make sure to handle an empty indices mapping and remove the requested columns anyway) if len(self) == 0: if self._indices is not None: # empty indices mapping @@ -3045,7 +3052,7 @@ def map( cache_file_name = self._get_cache_file_path(new_fingerprint) dataset_kwargs["cache_file_name"] = cache_file_name - def load_processed_shard_from_cache(shard_kwargs): + def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset: """Load a processed shard from cache if it exists, otherwise throw an error.""" shard = shard_kwargs["shard"] # Check if we've already cached this computation (indexed by a hash) @@ -3056,64 +3063,98 @@ def load_processed_shard_from_cache(shard_kwargs): return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split) raise NonExistentDatasetError - num_shards = num_proc if num_proc is not None else 1 - if batched and drop_last_batch: - pbar_total = len(self) // num_shards // batch_size * num_shards * batch_size - else: - pbar_total = len(self) + def pbar_total(num_shards: int, batch_size: Optional[int]) -> int: + total = len(self) + if len(existing_cache_files) < num_shards: + total -= 
len(existing_cache_files) * total // num_shards + if batched and drop_last_batch: + batch_size = batch_size or 1 + return total // num_shards // batch_size * num_shards * batch_size + return total + + def get_existing_cache_file_map( + cache_file_name: Optional[str], + ) -> dict[int, list[str]]: + cache_files_by_num_proc: dict[int, list[str]] = defaultdict(list) + if cache_file_name is None: + return cache_files_by_num_proc + if os.path.exists(cache_file_name): + cache_files_by_num_proc[1] = [cache_file_name] + + suffix_pattern_parts: list[str] = [] + for literal_text, field_name, format_spec, _ in string_formatter.parse(suffix_template): + suffix_pattern_parts.append(re.escape(literal_text)) + if field_name: + # TODO: we may want to place restrictions on acceptable format_spec or we will fail to match + # someone's hexidecimal or scientific notation format 😵 + suffix_pattern_parts.append(f"(?P<{field_name}>\\d+)") + suffix_pattern = "".join(suffix_pattern_parts) + + cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name) + if not cache_file_ext: + raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}") + + cache_file_pattern = "^" + re.escape(cache_file_prefix) + suffix_pattern + re.escape(cache_file_ext) + "$" + cache_file_regex = re.compile(cache_file_pattern) + + for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"): + if m := cache_file_regex.match(cache_file): + file_num_proc = int(m.group("num_proc")) + cache_files_by_num_proc[file_num_proc].append(cache_file) + + return cache_files_by_num_proc + + existing_cache_file_map = get_existing_cache_file_map(cache_file_name) + + num_shards = num_proc or 1 + if existing_cache_file_map: + # to avoid remapping when a different num_proc is given than when originally cached, update num_shards to + # what was used originally + + def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]: + percent_missing = (mapped_num_proc - len(existing_cache_file_map[mapped_num_proc])) / mapped_num_proc + num_shards_diff = abs(mapped_num_proc - num_shards) + return ( + percent_missing, # choose the most complete set of existing cache files + num_shards_diff, # then choose the mapped_num_proc closest to the current num_proc + mapped_num_proc, # finally, choose whichever mapped_num_proc is lower + ) - shards_done = 0 - if num_proc is None or num_proc == 1: - transformed_dataset = None - try: - transformed_dataset = load_processed_shard_from_cache(dataset_kwargs) - logger.info(f"Loading cached processed dataset at {dataset_kwargs['cache_file_name']}") - except NonExistentDatasetError: - pass - if transformed_dataset is None: - with hf_tqdm( - unit=" examples", - total=pbar_total, - desc=desc or "Map", - ) as pbar: - for rank, done, content in Dataset._map_single(**dataset_kwargs): - if done: - shards_done += 1 - logger.debug(f"Finished processing shard number {rank} of {num_shards}.") - transformed_dataset = content - else: - pbar.update(content) - assert transformed_dataset is not None, "Failed to retrieve the result from map" - # update fingerprint if the dataset changed - if transformed_dataset._fingerprint != self._fingerprint: - transformed_dataset._fingerprint = new_fingerprint - return transformed_dataset - else: + num_shards = min(existing_cache_file_map, key=select_existing_cache_files) - def format_cache_file_name( - cache_file_name: Optional[str], - rank: Union[int, Literal["*"]], # noqa: F722 - ) -> Optional[str]: - if not cache_file_name: - return cache_file_name - 
sep = cache_file_name.rindex(".") - base_name, extension = cache_file_name[:sep], cache_file_name[sep:] - if isinstance(rank, int): - cache_file_name = base_name + suffix_template.format(rank=rank, num_proc=num_proc) + extension - logger.info(f"Process #{rank} will write at {cache_file_name}") - else: - cache_file_name = ( - base_name - + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_proc) - + extension - ) + existing_cache_files = existing_cache_file_map.get(num_shards, []) + + def format_cache_file_name( + cache_file_name: Optional[str], + rank: Union[int, Literal["*"]], # noqa: F722 + ) -> Optional[str]: + if not cache_file_name: return cache_file_name - def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: - new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_proc) - validate_fingerprint(new_fingerprint) - return new_fingerprint + cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name) + if not cache_file_ext: + raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}") + + if isinstance(rank, int): + cache_file_name = ( + cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext + ) + logger.info(f"Process #{rank} will write at {cache_file_name}") + else: + # TODO: this assumes the format_spec of rank in suffix_template + cache_file_name = ( + cache_file_prefix + + suffix_template.replace("{rank:05d}", "{rank}").format(rank=rank, num_proc=num_shards) + + cache_file_ext + ) + return cache_file_name + + def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: + new_fingerprint = new_fingerprint + suffix_template.format(rank=rank, num_proc=num_shards) + validate_fingerprint(new_fingerprint) + return new_fingerprint + if num_proc is not None and num_proc > 1: prev_env = deepcopy(os.environ) # check if parallelism if off # from https://github.com/huggingface/tokenizers/blob/bb668bc439dc34389b71dbb8ce0c597f15707b53/tokenizers/src/utils/parallelism.rs#L22 @@ -3128,9 +3169,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: ): logger.warning("Setting TOKENIZERS_PARALLELISM=false for forked processes.") os.environ["TOKENIZERS_PARALLELISM"] = "false" + else: + prev_env = os.environ + + kwargs_per_job: list[Optional[dict[str, Any]]] + if num_shards == 1: + shards = [self] + kwargs_per_job = [dataset_kwargs] + else: shards = [ - self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory) - for rank in range(num_proc) + self.shard(num_shards=num_shards, index=rank, contiguous=True, keep_in_memory=keep_in_memory) + for rank in range(num_shards) ] kwargs_per_job = [ { @@ -3144,60 +3193,89 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: for rank in range(num_shards) ] - transformed_shards = [None] * num_shards - for rank in range(num_shards): - try: - transformed_shards[rank] = load_processed_shard_from_cache(kwargs_per_job[rank]) - kwargs_per_job[rank] = None - except NonExistentDatasetError: - pass - - kwargs_per_job = [kwargs for kwargs in kwargs_per_job if kwargs is not None] - - # We try to create a pool with as many workers as dataset not yet cached. - if kwargs_per_job: - if len(kwargs_per_job) < num_shards: - logger.info( - f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache." 
- ) - with Pool(len(kwargs_per_job)) as pool: - os.environ = prev_env - logger.info(f"Spawning {num_proc} processes") - with hf_tqdm( - unit=" examples", - total=pbar_total, - desc=(desc or "Map") + f" (num_proc={num_proc})", - ) as pbar: + transformed_shards: list[Optional[Dataset]] = [None] * num_shards + for rank in range(num_shards): + try: + job_kwargs = kwargs_per_job[rank] + assert job_kwargs is not None + transformed_shards[rank] = load_processed_shard_from_cache(job_kwargs) + kwargs_per_job[rank] = None + except NonExistentDatasetError: + pass + + if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]: + if len(unprocessed_kwargs_per_job) < num_shards: + logger.info( + f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were " + " missing from the cache." + ) + + with hf_tqdm( + unit=" examples", + total=pbar_total(num_shards, batch_size), + desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""), + ) as pbar: + shards_done = 0 + + def check_if_shard_done(rank: Optional[int], done: bool, content: Union[Dataset, int]) -> None: + nonlocal shards_done + if done: + shards_done += 1 + logger.debug(f"Finished processing shard number {rank} of {num_shards}.") + assert isinstance(content, Dataset) + transformed_shards[rank or 0] = content + else: + assert isinstance(content, int) + pbar.update(content) + + if num_proc is not None and num_proc > 1: + with Pool(num_proc) as pool: + os.environ = prev_env + logger.info(f"Spawning {num_proc} processes") + for rank, done, content in iflatmap_unordered( - pool, Dataset._map_single, kwargs_iterable=kwargs_per_job + pool, Dataset._map_single, kwargs_iterable=unprocessed_kwargs_per_job ): - if done: - shards_done += 1 - logger.debug(f"Finished processing shard number {rank} of {num_shards}.") - transformed_shards[rank] = content - else: - pbar.update(content) - pool.close() - pool.join() - # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805) - for kwargs in kwargs_per_job: - del kwargs["shard"] - else: - logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}") - assert None not in transformed_shards, ( - f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results" + check_if_shard_done(rank, done, content) + + pool.close() + pool.join() + else: + for unprocessed_kwargs in unprocessed_kwargs_per_job: + for rank, done, content in Dataset._map_single(**unprocessed_kwargs): + check_if_shard_done(rank, done, content) + + # Avoids PermissionError on Windows (the error: https://github.com/huggingface/datasets/actions/runs/4026734820/jobs/6921621805) + for job_kwargs in unprocessed_kwargs_per_job: + if "shard" in job_kwargs: + del job_kwargs["shard"] + else: + logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}") + + all_transformed_shards = [shard for shard in transformed_shards if shard is not None] + if len(transformed_shards) != len(all_transformed_shards): + raise ValueError( + f"Failed to retrieve results from map: result list {transformed_shards} still contains None - " + "at least one worker failed to return its results" ) - logger.info(f"Concatenating {num_proc} shards") - result = _concatenate_map_style_datasets(transformed_shards) - # update fingerprint if the dataset changed + + if 
num_shards == 1: + result = all_transformed_shards[0] + else: + logger.info(f"Concatenating {num_shards} shards") + result = _concatenate_map_style_datasets(all_transformed_shards) + + # update fingerprint if the dataset changed + result._fingerprint = ( + new_fingerprint if any( transformed_shard._fingerprint != shard._fingerprint - for transformed_shard, shard in zip(transformed_shards, shards) - ): - result._fingerprint = new_fingerprint - else: - result._fingerprint = self._fingerprint - return result + for transformed_shard, shard in zip(all_transformed_shards, shards) + ) + else self._fingerprint + ) + + return result @staticmethod def _map_single( @@ -3219,7 +3297,7 @@ def _map_single( new_fingerprint: Optional[str] = None, rank: Optional[int] = None, offset: int = 0, - ) -> Iterable[Tuple[int, bool, Union[int, "Dataset"]]]: + ) -> Iterable[Tuple[Optional[int], bool, Union[int, "Dataset"]]]: """Apply a function to all the elements in the table (individually or in batches) and update the table (if function does update examples). diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 2e54aadf7b6..d8083b79ca3 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1459,6 +1459,91 @@ def test_map_caching(self, in_memory): finally: datasets.enable_caching() + def test_suffix_template_format(self, in_memory): + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + self.assertRaises(ValueError) as e, + dset.map(lambda x: {"foo": "bar"}, suffix_template="_{}_of_{}"), + ): + self.assertIn( + "suffix_template must contain exactly the fields 'rank' and 'num_proc', got: ", + e.exception.args[0], + ) + + def test_cache_file_name_no_ext_raises_error(self, in_memory): + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + self.assertRaises(ValueError) as e, + dset.map(lambda x: {"foo": "bar"}, cache_file_name=os.path.join(tmp_dir, "train")), + ): + self.assertIn("Expected cache_file_name to have an extension, but got: ", e.exception.args[0]) + + def test_map_caching_reuses_cache_with_different_num_proc(self, in_memory): + for dset_test1_num_proc, dset_test2_num_proc in [(1, 2), (2, 1)]: + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + ): + # cannot mock _map_single here because mock objects aren't picklable + # see: https://github.com/python/cpython/issues/100090 + self._caplog.clear() + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test1_num_proc) as dset_test1: + dset_test1_data_files = list(dset_test1.cache_files) + self.assertFalse("Loading cached processed dataset" in self._caplog.text) + + self._caplog.clear() + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test2_num_proc) as dset_test2: + self.assertEqual(dset_test1_data_files, dset_test2.cache_files) + self.assertEqual(len(dset_test2.cache_files), 0 if in_memory else dset_test1_num_proc) + self.assertTrue(("Loading cached processed dataset" in self._caplog.text) ^ in_memory) + + def test_map_caching_partial_remap(self, in_memory): + with ( + tempfile.TemporaryDirectory() as tmp_dir, + self._caplog.at_level(INFO, logger=get_logger().name), + self._create_dummy_dataset(in_memory, tmp_dir) as dset, + ): + # cannot mock _map_single here 
because mock objects aren't picklable + # see: https://github.com/python/cpython/issues/100090 + self._caplog.clear() + dset_test1_num_proc = 4 + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test1_num_proc) as dset_test1: + dset_test1_data_files = list(dset_test1.cache_files) + self.assertFalse("Loading cached processed dataset" in self._caplog.text) + + num_files_to_delete = 2 + expected_msg = ( + f"Reprocessing {num_files_to_delete}/{dset_test1_num_proc} shards because some of them " + "were missing from the cache." + ) + for cache_file in dset_test1_data_files[num_files_to_delete:]: + os.remove(cache_file["filename"]) + + self._caplog.clear() + dset_test2_num_proc = 1 + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test2_num_proc) as dset_test2: + self.assertEqual(dset_test1_data_files, dset_test2.cache_files) + self.assertEqual(len(dset_test2.cache_files), 0 if in_memory else dset_test1_num_proc) + self.assertTrue((expected_msg in self._caplog.text) ^ in_memory) + self.assertFalse(f"Spawning {dset_test1_num_proc} processes" in self._caplog.text) + self.assertFalse(f"Spawning {dset_test2_num_proc} processes" in self._caplog.text) + + for cache_file in dset_test1_data_files[num_files_to_delete:]: + os.remove(cache_file["filename"]) + + self._caplog.clear() + dset_test3_num_proc = 3 + with dset.map(lambda x: {"foo": "bar"}, num_proc=dset_test3_num_proc) as dset_test3: + self.assertEqual(dset_test1_data_files, dset_test3.cache_files) + self.assertEqual(len(dset_test3.cache_files), 0 if in_memory else dset_test1_num_proc) + self.assertTrue((expected_msg in self._caplog.text) ^ in_memory) + self.assertTrue(f"Spawning {dset_test3_num_proc} processes" in self._caplog.text) + def test_map_return_pa_table(self, in_memory): def func_return_single_row_pa_table(x): return pa.table({"id": [0], "text": ["a"]}) From bdc17c9fce8650174b8f7e29bc4f74813c175776 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Tue, 4 Mar 2025 10:23:09 -0600 Subject: [PATCH 2/9] Only give reprocessing message doing a partial remap also fix spacing in message --- src/datasets/arrow_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 1af65ba2e20..37e90197551 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3204,10 +3204,10 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: pass if unprocessed_kwargs_per_job := [kwargs for kwargs in kwargs_per_job if kwargs is not None]: - if len(unprocessed_kwargs_per_job) < num_shards: + if len(unprocessed_kwargs_per_job) != num_shards: logger.info( f"Reprocessing {len(unprocessed_kwargs_per_job)}/{num_shards} shards because some of them were " - " missing from the cache." + "missing from the cache." 
) with hf_tqdm( From d7c63fd76dd17b8092953c645a09e9778b85c9ba Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Tue, 4 Mar 2025 10:31:42 -0600 Subject: [PATCH 3/9] Update logging message to account for if a cache file will be written at all and written by the main process or not --- src/datasets/arrow_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 37e90197551..f72d587174e 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3139,7 +3139,11 @@ def format_cache_file_name( cache_file_name = ( cache_file_prefix + suffix_template.format(rank=rank, num_proc=num_shards) + cache_file_ext ) - logger.info(f"Process #{rank} will write at {cache_file_name}") + if not os.path.exists(cache_file_name): + process_name = ( + "Main process" if num_proc is None or num_proc == 1 else f"Process #{rank % num_shards + 1}" + ) + logger.info(f"{process_name} will write at {cache_file_name}") else: # TODO: this assumes the format_spec of rank in suffix_template cache_file_name = ( From 0df413201f741df5f74921384e18c9f97efe1e85 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Tue, 4 Mar 2025 15:58:16 -0600 Subject: [PATCH 4/9] Refactor string_to_dict to return None if there is no match instead of raising ValueError instead of having the pattern of using try-except to handle when there is no match, we can instead check if the return value is None; we can also assert that the return value should not be None if we know that should be true --- src/datasets/arrow_dataset.py | 25 ++++++++------- src/datasets/data_files.py | 14 +++++---- src/datasets/dataset_dict.py | 14 ++++----- src/datasets/features/audio.py | 8 ++--- src/datasets/features/image.py | 8 ++--- src/datasets/features/video.py | 57 +++++++++++++++++++++------------- src/datasets/utils/py_utils.py | 13 ++++---- tests/test_py_utils.py | 20 ++++++++++++ 8 files changed, 97 insertions(+), 62 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 6c8a48f7757..8d8324c9b08 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3184,9 +3184,11 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: del kwargs["shard"] else: logger.info(f"Loading cached processed dataset at {format_cache_file_name(cache_file_name, '*')}") - assert None not in transformed_shards, ( - f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at least one worker failed to return its results" - ) + if None in transformed_shards: + raise ValueError( + f"Failed to retrieve results from map: result list {transformed_shards} still contains None - at " + "least one worker failed to return its results" + ) logger.info(f"Concatenating {num_proc} shards") result = _concatenate_map_style_datasets(transformed_shards) # update fingerprint if the dataset changed @@ -5333,7 +5335,7 @@ def _push_parquet_shards_to_hub( max_shard_size: Optional[Union[int, str]] = None, num_shards: Optional[int] = None, embed_external_files: bool = True, - ) -> Tuple[str, str, int, int, List[str], int]: + ) -> tuple[list[CommitOperationAdd], int, int]: """Pushes the dataset shards as Parquet files to the hub. 
Returns: @@ -5379,7 +5381,7 @@ def shards_with_embedded_external_files(shards: Iterator[Dataset]) -> Iterator[D api = HfApi(endpoint=config.HF_ENDPOINT, token=token) uploaded_size = 0 - additions = [] + additions: list[CommitOperationAdd] = [] for index, shard in hf_tqdm( enumerate(shards), desc="Uploading the dataset shards", @@ -5564,8 +5566,9 @@ def push_to_hub( # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern) # and delete old split shards (if they exist) repo_with_dataset_card, repo_with_dataset_infos = False, False - deletions, deleted_size = [], 0 - repo_splits = [] # use a list to keep the order of the splits + deletions: list[CommitOperationDelete] = [] + deleted_size = 0 + repo_splits: list[str] = [] # use a list to keep the order of the splits repo_files_to_add = [addition.path_in_repo for addition in additions] for repo_file in api.list_repo_tree( repo_id=repo_id, revision=revision, repo_type="dataset", token=token, recursive=True @@ -5584,10 +5587,10 @@ def push_to_hub( elif fnmatch.fnmatch( repo_file.rfilename, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*") ): - repo_split = string_to_dict( - repo_file.rfilename, - glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED), - )["split"] + pattern = glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED) + split_pattern_fields = string_to_dict(repo_file.rfilename, pattern) + assert split_pattern_fields is not None + repo_split = split_pattern_fields["split"] if repo_split not in repo_splits: repo_splits.append(repo_split) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 26987e299a2..7d162e79f14 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -3,7 +3,7 @@ from functools import partial from glob import has_magic from pathlib import Path, PurePath -from typing import Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import huggingface_hub from fsspec.core import url_to_fs @@ -276,14 +276,16 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], List[str]]) -> Di except FileNotFoundError: continue if len(data_files) > 0: - splits: Set[str] = { - string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern)))["split"] - for p in data_files - } + splits: set[str] = set() + for p in data_files: + p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern))) + assert p_parts is not None + splits.add(p_parts["split"]) + if any(not re.match(_split_re, split) for split in splits): raise ValueError(f"Split name should match '{_split_re}'' but got '{splits}'.") sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted( - splits - set(DEFAULT_SPLITS) + splits - {str(split) for split in DEFAULT_SPLITS} ) return {split: [split_pattern.format(split=split)] for split in sorted_splits} # then check the default patterns based on train/valid/test splits diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 251cbaa5055..c75d4e04221 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1700,8 +1700,8 @@ def push_to_hub( # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern) # and delete old split shards (if they exist) repo_with_dataset_card, repo_with_dataset_infos = 
False, False - repo_splits = [] # use a list to keep the order of the splits - deletions = [] + repo_splits: list[str] = [] # use a list to keep the order of the splits + deletions: list[CommitOperationDelete] = [] repo_files_to_add = [addition.path_in_repo for addition in additions] for repo_file in api.list_repo_tree( repo_id=repo_id, revision=revision, repo_type="dataset", token=token, recursive=True @@ -1720,12 +1720,12 @@ def push_to_hub( elif fnmatch.fnmatch( repo_file.rfilename, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*") ): - repo_split = string_to_dict( - repo_file.rfilename, - glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED), - )["split"] + pattern = glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED) + split_pattern_fields = string_to_dict(repo_file.rfilename, pattern) + assert split_pattern_fields is not None + repo_split = split_pattern_fields["split"] if repo_split not in repo_splits: - repo_splits.append(split) + repo_splits.append(repo_split) # get the info from the README to update them if repo_with_dataset_card: diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py index f7df47b7a06..dfc3e573233 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -173,11 +173,9 @@ def decode_example( pattern = ( config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL ) - try: - repo_id = string_to_dict(source_url, pattern)["repo_id"] - token = token_per_repo_id[repo_id] - except (ValueError, KeyError): - token = None + source_url_fields = string_to_dict(source_url, pattern) + assert source_url_fields is not None + token = token_per_repo_id.get(source_url_fields["repo_id"]) download_config = DownloadConfig(token=token) with xopen(path, "rb", download_config=download_config) as f: diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index 0393689fc46..9cfcbc2bb8c 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -174,11 +174,9 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "PIL.Image.Imag if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL ) - try: - repo_id = string_to_dict(source_url, pattern)["repo_id"] - token = token_per_repo_id.get(repo_id) - except ValueError: - token = None + source_url_fields = string_to_dict(source_url, pattern) + assert source_url_fields is not None + token = token_per_repo_id.get(source_url_fields["repo_id"]) download_config = DownloadConfig(token=token) with xopen(path, "rb", download_config=download_config) as f: bytes_ = BytesIO(f.read()) diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py index 2cde83930ac..a02f345543b 100644 --- a/src/datasets/features/video.py +++ b/src/datasets/features/video.py @@ -1,7 +1,7 @@ import os from dataclasses import dataclass, field from io import BytesIO -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, TypedDict, Union import numpy as np import pyarrow as pa @@ -19,6 +19,11 @@ from .features import FeatureType +class Example(TypedDict): + path: Optional[str] + bytes: Optional[bytes] + + @dataclass class Video: """ @@ -71,7 +76,7 @@ def __post_init__(self): def __call__(self): return self.pa_type - def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader"]) -> dict: + def 
encode_example(self, value: Union[str, bytes, Example, np.ndarray, "VideoReader"]) -> Example: """Encode example into a format for Arrow. Args: @@ -97,21 +102,29 @@ def encode_example(self, value: Union[str, bytes, dict, np.ndarray, "VideoReader elif isinstance(value, np.ndarray): # convert the video array to bytes return encode_np_array(value) - elif VideoReader and isinstance(value, VideoReader): + elif VideoReader is not None and isinstance(value, VideoReader): # convert the decord video reader to bytes return encode_decord_video(value) - elif value.get("path") is not None and os.path.isfile(value["path"]): - # we set "bytes": None to not duplicate the data if they're already available locally - return {"bytes": None, "path": value.get("path")} - elif value.get("bytes") is not None or value.get("path") is not None: - # store the video bytes, and path is used to infer the video format using the file extension - return {"bytes": value.get("bytes"), "path": value.get("path")} + elif isinstance(value, dict): + path, bytes_ = value.get("path"), value.get("bytes") + if path is not None and os.path.isfile(path): + # we set "bytes": None to not duplicate the data if they're already available locally + return {"bytes": None, "path": path} + elif bytes_ is not None or path is not None: + # store the video bytes, and path is used to infer the video format using the file extension + return {"bytes": bytes_, "path": path} + else: + raise ValueError( + f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}." + ) else: - raise ValueError( - f"A video sample should have one of 'path' or 'bytes' but they are missing or None in {value}." - ) + raise TypeError(f"Unsupported encode_example type: {type(value)}") - def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader": + def decode_example( + self, + value: Union[str, Example], + token_per_repo_id: Optional[dict[str, Union[bool, str]]] = None, + ) -> "VideoReader": """Decode example video file into video data. 
Args: @@ -141,7 +154,11 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader": if token_per_repo_id is None: token_per_repo_id = {} - path, bytes_ = value["path"], value["bytes"] + if isinstance(value, str): + path, bytes_ = value, None + else: + path, bytes_ = value["path"], value["bytes"] + if bytes_ is None: if path is None: raise ValueError(f"A video should have one of 'path' or 'bytes' but both are None in {value}.") @@ -155,11 +172,9 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "VideoReader": if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL ) - try: - repo_id = string_to_dict(source_url, pattern)["repo_id"] - token = token_per_repo_id.get(repo_id) - except ValueError: - token = None + source_url_fields = string_to_dict(source_url, pattern) + assert source_url_fields is not None + token = token_per_repo_id.get(source_url_fields["repo_id"]) download_config = DownloadConfig(token=token) with xopen(path, "rb", download_config=download_config) as f: bytes_ = BytesIO(f.read()) @@ -233,7 +248,7 @@ def video_to_bytes(video: "VideoReader") -> bytes: raise NotImplementedError() -def encode_decord_video(video: "VideoReader") -> dict: +def encode_decord_video(video: "VideoReader") -> Example: if hasattr(video, "_hf_encoded"): return video._hf_encoded else: @@ -243,7 +258,7 @@ def encode_decord_video(video: "VideoReader") -> dict: ) -def encode_np_array(array: np.ndarray) -> dict: +def encode_np_array(array: np.ndarray) -> Example: raise NotImplementedError() diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 55602e26996..9bee99fcbe8 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -30,7 +30,7 @@ from pathlib import Path from queue import Empty from shutil import disk_usage -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union +from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, TypeVar, Union from urllib.parse import urlparse import multiprocess @@ -156,7 +156,7 @@ def glob_pattern_to_regex(pattern): ) -def string_to_dict(string: str, pattern: str) -> Dict[str, str]: +def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]: """Un-format a string using a python f-string pattern. From https://stackoverflow.com/a/36838374 @@ -174,15 +174,14 @@ def string_to_dict(string: str, pattern: str) -> Dict[str, str]: pattern (str): pattern formatted like a python f-string Returns: - Dict[str, str]: dictionary of variable -> value, retrieved from the input using the pattern - - Raises: - ValueError: if the string doesn't match the pattern + Optional[dict[str, str]]: dictionary of variable -> value, retrieved from the input using the pattern, or + `None` if the string does not match the pattern. """ + pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern) # remove format specifiers, e.g. 
{rank:05d} -> {rank} regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern) result = re.search(regex, string) if result is None: - raise ValueError(f"String {string} doesn't match the pattern {pattern}") + return None values = list(result.groups()) keys = re.findall(r"{(.+?)}", pattern) _dict = dict(zip(keys, values)) diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py index d9d95969aff..d3e7795bf9d 100644 --- a/tests/test_py_utils.py +++ b/tests/test_py_utils.py @@ -1,3 +1,4 @@ +import os import time from dataclasses import dataclass from multiprocessing import Pool @@ -13,6 +14,7 @@ asdict, iflatmap_unordered, map_nested, + string_to_dict, temp_seed, temporary_assignment, zip_dict, @@ -267,3 +269,21 @@ def test_iflatmap_unordered(): assert out.count("a") == 2 assert out.count("b") == 2 assert len(out) == 4 + + +def test_string_to_dict(): + file_name = "dataset/cache-3b163736cf4505085d8b5f9b4c266c26.arrow" + file_name_prefix, file_name_ext = os.path.splitext(file_name) + + suffix_template = "_{rank:05d}_of_{num_proc:05d}" + cache_file_name_pattern = file_name_prefix + suffix_template + file_name_ext + + file_name_parts = string_to_dict(file_name, cache_file_name_pattern) + assert file_name_parts is None + + rank = 1 + num_proc = 2 + file_name = file_name_prefix + suffix_template.format(rank=rank, num_proc=num_proc) + file_name_ext + file_name_parts = string_to_dict(file_name, cache_file_name_pattern) + assert file_name_parts is not None + assert file_name_parts == {"rank": f"{rank:05d}", "num_proc": f"{num_proc:05d}"} From 79dc83bdccf62a0c12970ee1b527aadef1633ca5 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Tue, 4 Mar 2025 16:16:21 -0600 Subject: [PATCH 5/9] Simplify existing existing_cache_file_map with string_to_dict https://github.com/huggingface/datasets/pull/7434#discussion_r1979719019 --- src/datasets/arrow_dataset.py | 53 +++++++++++++---------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index d2358177622..e22118fbdf0 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3052,6 +3052,13 @@ def map( cache_file_name = self._get_cache_file_path(new_fingerprint) dataset_kwargs["cache_file_name"] = cache_file_name + if cache_file_name is not None: + cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name) + if not cache_file_ext: + raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}") + else: + cache_file_prefix = cache_file_ext = None + def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset: """Load a processed shard from cache if it exists, otherwise throw an error.""" shard = shard_kwargs["shard"] @@ -3072,39 +3079,19 @@ def pbar_total(num_shards: int, batch_size: Optional[int]) -> int: return total // num_shards // batch_size * num_shards * batch_size return total - def get_existing_cache_file_map( - cache_file_name: Optional[str], - ) -> dict[int, list[str]]: - cache_files_by_num_proc: dict[int, list[str]] = defaultdict(list) - if cache_file_name is None: - return cache_files_by_num_proc + existing_cache_file_map: dict[int, list[str]] = defaultdict(list) + if cache_file_name is not None: if os.path.exists(cache_file_name): - cache_files_by_num_proc[1] = [cache_file_name] + existing_cache_file_map[1] = [cache_file_name] - suffix_pattern_parts: list[str] = [] - for literal_text, field_name, format_spec, _ in string_formatter.parse(suffix_template): - 
suffix_pattern_parts.append(re.escape(literal_text)) - if field_name: - # TODO: we may want to place restrictions on acceptable format_spec or we will fail to match - # someone's hexidecimal or scientific notation format 😵 - suffix_pattern_parts.append(f"(?P<{field_name}>\\d+)") - suffix_pattern = "".join(suffix_pattern_parts) - - cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name) - if not cache_file_ext: - raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}") - - cache_file_pattern = "^" + re.escape(cache_file_prefix) + suffix_pattern + re.escape(cache_file_ext) + "$" - cache_file_regex = re.compile(cache_file_pattern) + assert cache_file_prefix is not None and cache_file_ext is not None + cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"): - if m := cache_file_regex.match(cache_file): - file_num_proc = int(m.group("num_proc")) - cache_files_by_num_proc[file_num_proc].append(cache_file) - - return cache_files_by_num_proc - - existing_cache_file_map = get_existing_cache_file_map(cache_file_name) + suffix_variable_map = string_to_dict(cache_file, cache_file_with_suffix_pattern) + if suffix_variable_map is not None: + file_num_proc = int(suffix_variable_map["num_proc"]) + existing_cache_file_map[file_num_proc].append(cache_file) num_shards = num_proc or 1 if existing_cache_file_map: @@ -3122,7 +3109,7 @@ def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]: num_shards = min(existing_cache_file_map, key=select_existing_cache_files) - existing_cache_files = existing_cache_file_map.get(num_shards, []) + existing_cache_files = existing_cache_file_map[num_shards] def format_cache_file_name( cache_file_name: Optional[str], @@ -3131,9 +3118,7 @@ def format_cache_file_name( if not cache_file_name: return cache_file_name - cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name) - if not cache_file_ext: - raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}") + assert cache_file_prefix is not None and cache_file_ext is not None if isinstance(rank, int): cache_file_name = ( @@ -5835,7 +5820,7 @@ def push_to_hub( @transmit_format @fingerprint_transform(inplace=False) def add_column( - self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None + self, name: str, column: Union[list, np.ndarray], new_fingerprint: str, feature: Optional[FeatureType] = None ): """Add column to Dataset. 
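For reference, the cache-shard discovery that this series implements can be condensed into a standalone sketch. It mirrors the patched string_to_dict (without the Windows-path escaping added later in the series) and the existing_cache_file_map loop from the diff above; the grouping helper's name and the default suffix template here are illustrative assumptions rather than public datasets API.

    import glob
    import os
    import re
    from collections import defaultdict
    from typing import Optional


    def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]:
        # Un-format `string` with an f-string-like `pattern`; return None if it does not match.
        pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern)  # {rank:05d} -> {rank}
        regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern)
        result = re.search(regex, string)
        if result is None:
            return None
        keys = re.findall(r"{(.+?)}", pattern)
        return dict(zip(keys, result.groups()))


    def group_cache_files_by_num_proc(
        cache_file_name: str, suffix_template: str = "_{rank:05d}_of_{num_proc:05d}"
    ) -> dict[int, list[str]]:
        # Group every shard cache file sitting next to `cache_file_name` by the num_proc it was written with.
        cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
        cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext
        files_by_num_proc: dict[int, list[str]] = defaultdict(list)
        if os.path.exists(cache_file_name):  # an unsuffixed file is a previous num_proc=1 (or None) run
            files_by_num_proc[1] = [cache_file_name]
        for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
            fields = string_to_dict(cache_file, cache_file_with_suffix_pattern)
            if fields is not None:
                files_by_num_proc[int(fields["num_proc"])].append(cache_file)
        return files_by_num_proc

map() then picks the key of this mapping whose shard set is most complete (ties broken by closeness to the requested num_proc, then by the smaller value), loads the shards that already exist, and only reprocesses the missing ones.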
From bb7f9b55d6cd795bb7cd7f2d74faa574a2e1456e Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Tue, 4 Mar 2025 16:30:28 -0600 Subject: [PATCH 6/9] Set initial value if there are already existing cache files https://github.com/huggingface/datasets/pull/7434#discussion_r1979716904 --- src/datasets/arrow_dataset.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index e22118fbdf0..fcc978daa33 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -3070,15 +3070,6 @@ def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset: return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split) raise NonExistentDatasetError - def pbar_total(num_shards: int, batch_size: Optional[int]) -> int: - total = len(self) - if len(existing_cache_files) < num_shards: - total -= len(existing_cache_files) * total // num_shards - if batched and drop_last_batch: - batch_size = batch_size or 1 - return total // num_shards // batch_size * num_shards * batch_size - return total - existing_cache_file_map: dict[int, list[str]] = defaultdict(list) if cache_file_name is not None: if os.path.exists(cache_file_name): @@ -3199,9 +3190,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str: "missing from the cache." ) + pbar_total = len(self) + pbar_initial = len(existing_cache_files) * pbar_total // num_shards + if batched and drop_last_batch: + batch_size = batch_size or 1 + pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size + pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size + with hf_tqdm( unit=" examples", - total=pbar_total(num_shards, batch_size), + initial=pbar_initial, + total=pbar_total, desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""), ) as pbar: shards_done = 0 From c82cab4a3d373e48502d7c17a33c48a9b5af17cc Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Fri, 7 Mar 2025 14:19:11 -0600 Subject: [PATCH 7/9] Allow for source_url_fields to be None they can be local file paths here https://github.com/huggingface/datasets/actions/runs/13683185040/job/38380924390?pr=7435#step:10:9731 --- src/datasets/features/audio.py | 3 +-- src/datasets/features/image.py | 5 +++-- src/datasets/features/video.py | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py index ccad84a89e3..349d89f3e1b 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -174,8 +174,7 @@ def decode_example( config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL ) source_url_fields = string_to_dict(source_url, pattern) - assert source_url_fields is not None - token = token_per_repo_id.get(source_url_fields["repo_id"]) + token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None download_config = DownloadConfig(token=token) with xopen(path, "rb", download_config=download_config) as f: diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index c1efd8566b7..9682258b271 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -175,8 +175,9 @@ def decode_example(self, value: dict, token_per_repo_id=None) -> "PIL.Image.Imag else config.HUB_DATASETS_HFFS_URL ) source_url_fields = string_to_dict(source_url, pattern) - 
assert source_url_fields is not None - token = token_per_repo_id.get(source_url_fields["repo_id"]) + token = ( + token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None + ) download_config = DownloadConfig(token=token) with xopen(path, "rb", download_config=download_config) as f: bytes_ = BytesIO(f.read()) diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py index 8e42ec71074..11eae4812b3 100644 --- a/src/datasets/features/video.py +++ b/src/datasets/features/video.py @@ -263,8 +263,7 @@ def hf_video_reader( source_url = path.split("::")[-1] pattern = config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL source_url_fields = string_to_dict(source_url, pattern) - assert source_url_fields is not None - token = token_per_repo_id.get(source_url_fields["repo_id"]) + token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None download_config = DownloadConfig(token=token) f = xopen(path, "rb", download_config=download_config) From 637c1600fe7dd601eff571fda446937bd96c5c84 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Wed, 12 Mar 2025 13:51:47 -0500 Subject: [PATCH 8/9] Add unicode escape to handle parsing string_to_dict in Windows paths --- src/datasets/utils/py_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index d954f548f22..79606f71cd2 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -180,6 +180,7 @@ def string_to_dict(string: str, pattern: str) -> Optional[dict[str, str]]: Optional[dict[str, str]]: dictionary of variable -> value, retrieved from the input using the pattern, or `None` if the string does not match the pattern. """ + pattern = pattern.encode("unicode_escape").decode("utf-8") # C:\\Users -> C:\\\\Users for Windows paths pattern = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", pattern) # remove format specifiers, e.g. {rank:05d} -> {rank} regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", pattern) result = re.search(regex, string) From 583c28e7560b9d6db2e13048731f41ec8fa11361 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Fri, 14 Mar 2025 13:12:11 -0500 Subject: [PATCH 9/9] Remove glob_pattern_to_regex All the tests still pass when it is removed; I think the unicode escaping must do some of the work that glob_pattern_to_regex was doing here before --- src/datasets/data_files.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 2a1caa81bf7..bedbba7a795 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -19,7 +19,7 @@ from .utils import logging from .utils import tqdm as hf_tqdm from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin -from .utils.py_utils import glob_pattern_to_regex, string_to_dict +from .utils.py_utils import string_to_dict SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]] @@ -266,7 +266,7 @@ def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> di if len(data_files) > 0: splits: set[str] = set() for p in data_files: - p_parts = string_to_dict(xbasename(p), glob_pattern_to_regex(xbasename(split_pattern))) + p_parts = string_to_dict(xbasename(p), xbasename(split_pattern)) assert p_parts is not None splits.add(p_parts["split"])
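As a quick illustration of why PATCH 8/9 unicode-escapes the pattern: a Windows cache path contains backslashes that re would otherwise treat as (invalid) escape sequences such as \U, so they are doubled before the format specifiers are stripped. The paths below are made-up examples, not taken from the test suite.

    import re

    pattern = "C:\\Users\\me\\cache-3b16_{rank:05d}_of_{num_proc:05d}.arrow"
    string = "C:\\Users\\me\\cache-3b16_00001_of_00002.arrow"

    escaped = pattern.encode("unicode_escape").decode("utf-8")    # each "\" becomes "\\" so re reads it literally
    escaped = re.sub(r"{([^:}]+)(?::[^}]+)?}", r"{\1}", escaped)  # drop the format specifiers
    regex = re.sub(r"{(.+?)}", r"(?P<_\1>.+)", escaped)
    print(re.search(regex, string).groupdict())                   # {'_rank': '00001', '_num_proc': '00002'}
    # Without the unicode_escape step, re.search raises re.error ("bad escape \U") on the raw pattern.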