Skip to content

Commit

Permalink
Implement _sum method fo jax
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed Feb 20, 2023
1 parent 3fe94b6 commit 1242ecd
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions merlin/dataloader/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def _split_fn(self, tensor, idx, axis=0):

_tensor_split = _split_fn

def _sum(self, tensor):
return tensor.sum()

def _to_tensor(self, gdf):
if gdf.empty:
return
Expand Down

0 comments on commit 1242ecd

Please # to comment.