diff --git a/merlin/dataloader/loader_base.py b/merlin/dataloader/loader_base.py index 171c3ece..81f5fd5c 100644 --- a/merlin/dataloader/loader_base.py +++ b/merlin/dataloader/loader_base.py @@ -447,8 +447,7 @@ def _to_tensor(self, df_or_series): if self.device == "cpu": tensor = df_or_series.to_numpy() else: - with cupy.cuda.Device(self.device): - tensor = df_or_series.to_cupy() + tensor = df_or_series.to_cupy() return tensor diff --git a/merlin/dataloader/torch.py b/merlin/dataloader/torch.py index 6921ce7e..148126ec 100644 --- a/merlin/dataloader/torch.py +++ b/merlin/dataloader/torch.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import contextlib from functools import partial from merlin.core.compat.torch import torch as th @@ -119,11 +118,6 @@ def map(self, fn): return self - def _get_device_ctx(self, dev): - if dev == "cpu" or not th: - return contextlib.nullcontext() - return th.cuda.device(f"cuda:{dev}") - class DLDataLoader(th.utils.data.DataLoader): """