From 699ca7b0e72e921752da2a31b5bfb9764c9cf2ee Mon Sep 17 00:00:00 2001 From: Julio Perez <37191411+jperez999@users.noreply.github.com> Date: Wed, 5 Apr 2023 15:41:57 -0400 Subject: [PATCH] fixing imports for tf and torch in dataloader (#127) * fixing imports for tf and torch in dataloader * fix tests imports of frameworks --- merlin/dataloader/tensorflow.py | 2 +- merlin/dataloader/torch.py | 2 +- tests/unit/dataloader/test_array_to_tensorflow.py | 2 +- tests/unit/dataloader/test_array_to_torch.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/merlin/dataloader/tensorflow.py b/merlin/dataloader/tensorflow.py index 756c3ce0..cbd9e154 100644 --- a/merlin/dataloader/tensorflow.py +++ b/merlin/dataloader/tensorflow.py @@ -15,7 +15,7 @@ # from functools import partial -from merlin.core.compat import tensorflow as tf +from merlin.core.compat.tensorflow import tensorflow as tf from merlin.dataloader.loader_base import LoaderBase from merlin.table import TensorColumn, TensorflowColumn, TensorTable from merlin.table.conversions import _dispatch_dlpack_fns, convert_col diff --git a/merlin/dataloader/torch.py b/merlin/dataloader/torch.py index 4297f558..148126ec 100644 --- a/merlin/dataloader/torch.py +++ b/merlin/dataloader/torch.py @@ -15,7 +15,7 @@ # from functools import partial -from merlin.core.compat import torch as th +from merlin.core.compat.torch import torch as th from merlin.dataloader.loader_base import LoaderBase from merlin.table import TensorColumn, TensorTable, TorchColumn from merlin.table.conversions import _dispatch_dlpack_fns, convert_col diff --git a/tests/unit/dataloader/test_array_to_tensorflow.py b/tests/unit/dataloader/test_array_to_tensorflow.py index d17582c7..d24d3026 100644 --- a/tests/unit/dataloader/test_array_to_tensorflow.py +++ b/tests/unit/dataloader/test_array_to_tensorflow.py @@ -16,7 +16,7 @@ import pytest -from merlin.core.compat import tensorflow as tf +from merlin.core.compat.tensorflow import tensorflow as tf from merlin.core.dispatch import make_df from merlin.io import Dataset from merlin.schema import Tags diff --git a/tests/unit/dataloader/test_array_to_torch.py b/tests/unit/dataloader/test_array_to_torch.py index 3207eff0..c7200bee 100644 --- a/tests/unit/dataloader/test_array_to_torch.py +++ b/tests/unit/dataloader/test_array_to_torch.py @@ -15,7 +15,7 @@ # import pytest -from merlin.core.compat import torch as th +from merlin.core.compat.torch import torch as th from merlin.core.dispatch import make_df from merlin.io import Dataset from merlin.schema import Tags