From 1242ecd002512028573025638de4a13bbedfb8a9 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 20 Feb 2023 16:10:00 +0000 Subject: [PATCH] Implement _sum method fo jax --- merlin/dataloader/jax.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/merlin/dataloader/jax.py b/merlin/dataloader/jax.py index b4c9733c..7db6ce9a 100644 --- a/merlin/dataloader/jax.py +++ b/merlin/dataloader/jax.py @@ -69,6 +69,9 @@ def _split_fn(self, tensor, idx, axis=0): _tensor_split = _split_fn + def _sum(self, tensor): + return tensor.sum() + def _to_tensor(self, gdf): if gdf.empty: return