Skip to content

Commit

Permalink
Fix iterable.Cached. (#3060)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper authored Nov 26, 2023
1 parent 01f6787 commit 07ef9f0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/gluonts/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,9 @@ def split_into(xs: Sequence, n: int) -> Sequence:
@dataclass
class Cached:
"""
An iterable wrapper, which caches values in a list the first time it is
iterated.
An iterable wrapper, which caches values in a list while iterated.
The primary use-case for this is to avoid re-computing the element of the
The primary use-case for this is to avoid re-computing the elements of the
sequence, in case the inner iterable does it on demand.
This should be used to wrap deterministic iterables, i.e. iterables where
Expand All @@ -317,15 +316,21 @@ class Cached:
"""

iterable: SizedIterable
cache: list = field(default_factory=list, init=False)
provider: Iterable = field(init=False)
consumed: list = field(default_factory=list, init=False)

def __post_init__(self):
# ensure we only iterate once over the iterable
self.provider = iter(self.iterable)

def __iter__(self):
if not self.cache:
for element in self.iterable:
yield element
self.cache.append(element)
else:
yield from self.cache
# Yield already provided values first
yield from self.consumed

# Now yield remaining elements.
for element in self.provider:
self.consumed.append(element)
yield element

def __len__(self) -> int:
return len(self.iterable)
Expand Down
10 changes: 10 additions & 0 deletions test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ def test_pickle(iterable: Iterable, assert_content: bool):
assert data == data_copy


def test_cached_reentry():
data = Cached(range(10))

assert len(data) == 10
assert list(take(5, data)) == list(range(5))
assert len(data) == 10
assert list(take(10, data)) == list(range(10))
assert len(data) == 10


@pytest.mark.parametrize(
"given, expected",
[
Expand Down

0 comments on commit 07ef9f0

Please # to comment.