Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmilkov committed Dec 28, 2023
1 parent cd003cf commit d6c502f
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions lilac/batch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,33 @@ def unflatten(
TFlatBatchedOutput = TypeVar('TFlatBatchedOutput')


def group_by_sorted_key_iter(
input: Iterator[Item], key_fn: Callable[[Item], object]
) -> Iterator[list[Item]]:
"""Takes a key-sorted iterator and yields each group of items sharing the same key.
Args:
input: An iterator of items, ordered by the key.
key_fn: A function that takes an item and returns a key.
Yields:
A list of items sharing the same key, for each group.
"""
last_key: object = None
last_group: list[Item] = []
for item in input:
key = key_fn(item)
if key != last_key:
if last_group:
yield last_group
last_group = [item]
last_key = key
else:
last_group.append(item)
if last_group:
yield last_group


def flat_batched_compute(
input: Iterable[Iterable[TFlatBatchedInput]],
f: Callable[[list[TFlatBatchedInput]], Iterable[TFlatBatchedOutput]],
Expand Down

0 comments on commit d6c502f

Please # to comment.