def group_by_sorted_key_iter(
    input: Iterable[Item], key_fn: Callable[[Item], object]
) -> Iterator[list[Item]]:
    """Takes a key-sorted iterable and yields each group of items sharing the same key.

    The input is consumed lazily in a single pass, one group at a time. Only
    adjacent items are grouped together, so the input must already be ordered
    by the key for every key to end up in exactly one group; if it is not, the
    same key is yielded as several separate groups.

    Args:
      input: An iterable of items, ordered by the key.
      key_fn: A function that takes an item and returns a key. It is called once
        per item, and consecutive keys are compared for equality. Any key value
        is allowed, including `None`.

    Yields:
      A list of items sharing the same key, for each group, in input order.
      Nothing is yielded for an empty input.
    """
    # Imported locally so this function does not depend on the module's import
    # block already providing `itertools`.
    import itertools

    # `groupby` does the run detection (including flushing the final group), so
    # no "previous key" sentinel is needed here. Each group is materialized
    # before yielding because the underlying group iterator is invalidated as
    # soon as `groupby` advances to the next key.
    for _, group in itertools.groupby(input, key_fn):
        yield list(group)