diff --git a/tensordict/utils.py b/tensordict/utils.py
index 81f0ecb00..ce50b6fcd 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1615,13 +1615,21 @@ def _td_fields(td: T, keys=None, sep=": ") -> str:
                 # we know td is lazy stacked and the key is a leaf
                 # so we can get the shape and escape the error
                 temp_td = td
-                from tensordict import LazyStackedTensorDict, TensorDictBase
+                from tensordict import (
+                    is_tensor_collection,
+                    LazyStackedTensorDict,
+                    TensorDictBase,
+                )
 
                 while isinstance(
                     temp_td, LazyStackedTensorDict
-                ):  # we need to grab the het tensor from the inner nesting level
+                ):  # we need to grab the heterogeneous tensor from the inner nesting level
                     temp_td = temp_td.tensordicts[0]
                 tensor = temp_td.get(key)
+                if is_tensor_collection(tensor):
+                    tensor = td.get(key)
+                    strs.append(_make_repr(key, tensor, td, sep=sep))
+                    continue
 
                 if isinstance(tensor, TensorDictBase):
                     substr = _td_fields(tensor)
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 697590b14..928b5d317 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -7168,6 +7168,43 @@ def test_repr_nested(self, device, dtype):
     is_shared={is_shared})"""
         assert repr(nested_td) == expected
 
+    def test_repr_nested_lazy(self, device, dtype):
+        nested_td0 = self.nested_td(device, dtype)
+        nested_td1 = torch.cat([nested_td0, nested_td0], 1)
+        nested_td1["my_nested_td", "another"] = nested_td1["my_nested_td", "a"]
+        lazy_nested_td = TensorDict.lazy_stack([nested_td0, nested_td1], dim=1)
+
+        if device is not None and device.type == "cuda":
+            is_shared = True
+        else:
+            is_shared = False
+        tensor_class = "Tensor"
+        tensor_device = device if device else nested_td0[:, 0]["b"].device
+        if tensor_device.type == "cuda":
+            is_shared_tensor = True
+        else:
+            is_shared_tensor = is_shared
+        expected = f"""LazyStackedTensorDict(
+    fields={{
+        b: {tensor_class}(shape=torch.Size([4, 2, -1, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}),
+        my_nested_td: LazyStackedTensorDict(
+            fields={{
+                a: {tensor_class}(shape=torch.Size([4, 2, -1, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            exclusive_fields={{
+                1 ->
+                    another: Tensor(shape=torch.Size([4, 6, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}},
+            batch_size=torch.Size([4, 2, -1, 2, 1]),
+            device={str(device)},
+            is_shared={is_shared},
+            stack_dim=1)}},
+    exclusive_fields={{
+    }},
+    batch_size=torch.Size([4, 2, -1, 2, 1]),
+    device={str(device)},
+    is_shared={is_shared},
+    stack_dim=1)"""
+        assert repr(lazy_nested_td) == expected
+
     def test_repr_nested_update(self, device, dtype):
         nested_td = self.nested_td(device, dtype)
         nested_td["my_nested_td"].rename_key_("a", "z")
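
For context, a standalone sketch of the scenario the new test exercises. This is not part of the patch: the TensorDict built below stands in for the test's self.nested_td fixture, with shapes inferred from the torch.Size values in the expected repr, so the exact fixture contents are an assumption.

    import torch
    from tensordict import TensorDict

    # Stand-in for self.nested_td (hypothetical; shapes taken from the
    # expected repr in the new test).
    nested_td0 = TensorDict(
        {
            "my_nested_td": TensorDict(
                {"a": torch.zeros(4, 3, 2, 1, 5)}, batch_size=[4, 3, 2, 1]
            ),
            "b": torch.zeros(4, 3, 2, 1, 5),
        },
        batch_size=[4, 3, 2, 1],
    )
    # Concatenating along dim 1 gives the second copy a different batch size
    # (3 -> 6), and the extra "another" entry makes the nested keys exclusive.
    nested_td1 = torch.cat([nested_td0, nested_td0], 1)
    nested_td1["my_nested_td", "another"] = nested_td1["my_nested_td", "a"]
    lazy_nested_td = TensorDict.lazy_stack([nested_td0, nested_td1], dim=1)

    # With the utils.py change applied, the heterogeneous dimension prints as
    # -1 and "another" is listed under exclusive_fields of the nested lazy
    # stack, per the expected string in test_repr_nested_lazy.
    print(lazy_nested_td)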