# Patch summary (text ahead of the first "diff --git" header is not part of any hunk).
#
# merlin/dataloader/tensorflow.py
#   Before: peek() returned self.convert_batch(self._peek_next_batch()) directly.
#   After:  the converted batch is passed through every callable in self._map_fns,
#           in order. Each callable receives the current batch unpacked as
#           positional arguments, and its return value becomes the batch handed
#           to the next callable. The final value is returned.
#
# tests/unit/dataloader/test_tf_dataloader.py
#   Adds test_peek_map: builds a loader with batch_size=1, shuffle=False and
#   .map(_map_fn), where _map_fn ignores its inputs and returns fixed frames
#   with columns "b" and "c". It asserts that peek() yields x with keys {"b"}
#   and y with keys {"c"}, i.e. the mapped output rather than the raw batch.
#
# NOTE(review): the line breaks and indentation below were reconstructed from the
# hunk line counts and Python syntax -- confirm the leading whitespace (context
# lines in particular) against the original commit before applying.
diff --git a/merlin/dataloader/tensorflow.py b/merlin/dataloader/tensorflow.py
index cbd9e154..093339ed 100644
--- a/merlin/dataloader/tensorflow.py
+++ b/merlin/dataloader/tensorflow.py
@@ -88,7 +88,11 @@ def __next__(self):
 
     def peek(self):
         """Grab the next batch from the dataloader without removing it from the queue"""
-        return self.convert_batch(self._peek_next_batch())
+        converted_batch = self.convert_batch(self._peek_next_batch())
+        for map_fn in self._map_fns:
+            converted_batch = map_fn(*converted_batch)
+
+        return converted_batch
 
     def on_epoch_end(self):
         self.stop()
diff --git a/tests/unit/dataloader/test_tf_dataloader.py b/tests/unit/dataloader/test_tf_dataloader.py
index 382d5251..b60c4f3c 100644
--- a/tests/unit/dataloader/test_tf_dataloader.py
+++ b/tests/unit/dataloader/test_tf_dataloader.py
@@ -74,6 +74,21 @@ def test_peek():
     assert len(all_batches) == 3
 
 
+def test_peek_map():
+    inputs = make_df({"a": [1, 999, 999]})
+    dataset = Dataset(inputs)
+
+    def _map_fn(x_in, y_in):
+        x_out = make_df({"b": [42]})
+        y_out = make_df({"c": [43]})
+        return x_out, y_out
+
+    with tf_loader(dataset, batch_size=1, shuffle=False).map(_map_fn) as loader:
+        x, y = loader.peek()
+        assert set(x.keys()) == {"b"}
+        assert set(y.keys()) == {"c"}
+
+
 def test_set_input_schema():
     df = make_df({"a": [1, 2, 3], "b": [[4], [5, 6], [7]]})
     dataset = Dataset(df)