From 07ef9f0c88e3ce96ec5a3cf21c127ffc0efe9aa3 Mon Sep 17 00:00:00 2001 From: Jasper Date: Sun, 26 Nov 2023 11:48:05 +0100 Subject: [PATCH] Fix `iterable.Cached`. (#3060) --- src/gluonts/itertools.py | 25 +++++++++++++++---------- test/test_itertools.py | 10 ++++++++++ 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py index 8ed8d245cb..e90281a1d8 100644 --- a/src/gluonts/itertools.py +++ b/src/gluonts/itertools.py @@ -305,10 +305,9 @@ def split_into(xs: Sequence, n: int) -> Sequence: @dataclass class Cached: """ - An iterable wrapper, which caches values in a list the first time it is - iterated. + An iterable wrapper, which caches values in a list while iterated. - The primary use-case for this is to avoid re-computing the element of the + The primary use-case for this is to avoid re-computing the elements of the sequence, in case the inner iterable does it on demand. This should be used to wrap deterministic iterables, i.e. iterables where @@ -317,15 +316,21 @@ class Cached: """ iterable: SizedIterable - cache: list = field(default_factory=list, init=False) + provider: Iterable = field(init=False) + consumed: list = field(default_factory=list, init=False) + + def __post_init__(self): + # ensure we only iterate once over the iterable + self.provider = iter(self.iterable) def __iter__(self): - if not self.cache: - for element in self.iterable: - yield element - self.cache.append(element) - else: - yield from self.cache + # Yield already provided values first + yield from self.consumed + + # Now yield remaining elements. + for element in self.provider: + self.consumed.append(element) + yield element def __len__(self) -> int: return len(self.iterable) diff --git a/test/test_itertools.py b/test/test_itertools.py index 6cef2210f5..b9071d9ad0 100644 --- a/test/test_itertools.py +++ b/test/test_itertools.py @@ -119,6 +119,16 @@ def test_pickle(iterable: Iterable, assert_content: bool): assert data == data_copy +def test_cached_reentry(): + data = Cached(range(10)) + + assert len(data) == 10 + assert list(take(5, data)) == list(range(5)) + assert len(data) == 10 + assert list(take(10, data)) == list(range(10)) + assert len(data) == 10 + + @pytest.mark.parametrize( "given, expected", [