diff --git a/merlin/dataloader/torch.py b/merlin/dataloader/torch.py index 8cbb343f..6807f65c 100644 --- a/merlin/dataloader/torch.py +++ b/merlin/dataloader/torch.py @@ -93,6 +93,17 @@ def __init__( transforms=transforms, device=device, ) + self._map_fns = [] + + def map(self, fn): + """ + Applying a function to each batch. + + This can for instance be used to add `sample_weight` to the model. + """ + self._map_fns.append(fn) + + return self def __iter__(self): return LoaderBase.__iter__(self) @@ -165,6 +176,14 @@ def _build_sparse_tensor( sparse_tensor = sparse_tensor.to_dense() return sparse_tensor + def _handle_tensors(self, tensors, tensor_names): + to_return = super()._handle_tensors(tensors, tensor_names) + + for map_fn in self._map_fns: + to_return = map_fn(*to_return) + + return to_return + def _cast_to_numpy_dtype(self, dtype): """ Get the numpy dtype from the framework dtype. diff --git a/tests/unit/dataloader/test_torch_dataloader.py b/tests/unit/dataloader/test_torch_dataloader.py index cf175e98..c644d7b5 100644 --- a/tests/unit/dataloader/test_torch_dataloader.py +++ b/tests/unit/dataloader/test_torch_dataloader.py @@ -328,3 +328,41 @@ def test_dataloader_schema(df, dataset, batch_size, cpu): num_label_cols = y.shape[1] if len(y.shape) > 1 else 1 assert num_label_cols == 1 + + +def test_torch_map(tmpdir): + df = make_df( + { + "cat1": [1] * 100, + "cat2": [2] * 100, + "cat3": [3] * 100, + "label": [0] * 100, + "sample_weight": [1.0] * 100, + "cont2": [2.0] * 100, + "cont1": [1.0] * 100, + } + ) + path = os.path.join(tmpdir, "dataset.parquet") + df.to_parquet(path) + ds = Dataset(df) + ds.schema["label"] = ds.schema["label"].with_tags(Tags.TARGET) + + def add_sample_weight(features, labels, sample_weight_col_name="sample_weight"): + sample_weight = features.pop(sample_weight_col_name) + + return features, labels, sample_weight + + data_itr = torch_dataloader.Loader( + ds, + batch_size=10, + shuffle=False, + ).map(add_sample_weight) + + for X, y, sample_weight in data_itr: + assert list(X["cat1"].cpu().numpy()) == [1] * 10 + assert list(X["cat2"].cpu().numpy()) == [2] * 10 + assert list(X["cat3"].cpu().numpy()) == [3] * 10 + assert list(X["cont1"].cpu().numpy()) == [1.0] * 10 + assert list(X["cont2"].cpu().numpy()) == [2.0] * 10 + + assert list(sample_weight.cpu().numpy()) == [1.0] * 10