diff --git a/tensordict/_td.py b/tensordict/_td.py index a8f3ddce7..add07b1cf 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1302,7 +1302,7 @@ def _apply_nest( nested_keys: bool = False, prefix: tuple = (), filter_empty: bool | None = None, - is_leaf: Callable = None, + is_leaf: Callable | None = None, out: TensorDictBase | None = None, **constructor_kwargs, ) -> T | None: @@ -1319,9 +1319,18 @@ def _apply_nest( "batch_size and out.batch_size must be equal when both are provided." ) if device is not NO_DEFAULT and device != out.device: - raise RuntimeError( - "device and out.device must be equal when both are provided." - ) + if not checked: + raise RuntimeError( + f"device and out.device must be equal when both are provided. Got device={device} and out.device={out.device}." + ) + else: + device = torch.device(device) + out._device = device + for node in out.values(True, True, is_leaf=_is_tensor_collection): + if is_tensorclass(node): + node._tensordict._device = device + else: + node._device = device else: def make_result(names=names, batch_size=batch_size): diff --git a/tensordict/base.py b/tensordict/base.py index e8898a01a..1b328aaa1 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10586,6 +10586,8 @@ def to(tensor): else: apply_kwargs["device"] = device if device is not None else self.device apply_kwargs["batch_size"] = batch_size + apply_kwargs["out"] = self if inplace else None + apply_kwargs["checked"] = True if non_blocking_pin: def to_pinmem(tensor, _to=to): @@ -10595,7 +10597,19 @@ def to_pinmem(tensor, _to=to): to_pinmem, propagate_lock=True, **apply_kwargs ) else: - result = result._fast_apply(to, propagate_lock=True, **apply_kwargs) + # result = result._fast_apply(to, propagate_lock=True, **apply_kwargs) + keys, tensors = self._items_list(True, True) + tensors = [to(t) for t in tensors] + items = dict(zip(keys, tensors)) + result = self._fast_apply( + lambda name, val: items.get(name, val), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + **apply_kwargs, + ) + if batch_size is not None: result.batch_size = batch_size if (