Skip to content

Commit

Permalink
Apply map function when peeking
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv committed Apr 6, 2023
1 parent 699ca7b commit 8223910
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
6 changes: 5 additions & 1 deletion merlin/dataloader/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ def __next__(self):
def peek(self):
"""Grab the next batch from the dataloader
without removing it from the queue"""
return self.convert_batch(self._peek_next_batch())
converted_batch = self.convert_batch(self._peek_next_batch())
for map_fn in self._map_fns:
converted_batch = map_fn(*converted_batch)

return converted_batch

def on_epoch_end(self):
self.stop()
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/dataloader/test_tf_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,21 @@ def test_peek():
assert len(all_batches) == 3


def test_peek_map():
inputs = make_df({"a": [1, 999, 999]})
dataset = Dataset(inputs)

def _map_fn(x_in, y_in):
x_out = make_df({"b": [42]})
y_out = make_df({"c": [43]})
return x_out, y_out

with tf_loader(dataset, batch_size=1, shuffle=False).map(_map_fn) as loader:
x, y = loader.peek()
assert set(x.keys()) == {"b"}
assert set(y.keys()) == {"c"}


def test_set_input_schema():
df = make_df({"a": [1, 2, 3], "b": [[4], [5, 6], [7]]})
dataset = Dataset(df)
Expand Down

0 comments on commit 8223910

Please # to comment.