Skip to content

Commit

Permalink
Store values/row_lengths in dict in make_tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed Feb 20, 2023
1 parent fb7f6f5 commit 3fe94b6
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions merlin/dataloader/loader_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,21 @@ def make_tensors(self, gdf, use_row_lengths=False):
batch_row_lengths = self._split_fn(row_lengths, split_idx)
values_split_idx = [self._sum(row_lengths) for row_lengths in batch_row_lengths]
batch_values = self._split_fn(values, values_split_idx)
tensor_batches[tensor_key] = (batch_values, batch_row_lengths)
tensor_batches[tensor_key] = {
"values": batch_values,
"row_lengths": batch_row_lengths,
}
else:
tensor_batches[tensor_key] = self._split_fn(tensor_value, split_idx)

for batch_idx in range(len(split_idx)):
batch = {}
for tensor_key in tensors_by_name:
tensor_value = tensor_batches[tensor_key]
if isinstance(tensor_value, tuple):
batch[tensor_key] = tuple(
tuple_value[batch_idx] for tuple_value in tensor_value
if isinstance(tensor_value, dict):
batch[tensor_key] = (
tensor_value["values"][batch_idx],
tensor_value["row_lengths"][batch_idx],
)
else:
batch[tensor_key] = tensor_value[batch_idx]
Expand Down

0 comments on commit 3fe94b6

Please sign in to comment.