From e00965c8df2057ac02e1c90380cf0ca17640454c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 6 Nov 2024 14:53:35 +0000 Subject: [PATCH] [BugFix] Better repr of lazy stacks ghstack-source-id: 7256b4c95b239bf9e6467c0ea687abe2c9179922 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1076 (cherry picked from commit eaba7110be649c5fa23d0bc1fe87fb56b8bcc554) --- tensordict/utils.py | 12 ++++++++++-- test/test_tensordict.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index 81f0ecb00..ce50b6fcd 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1615,13 +1615,21 @@ def _td_fields(td: T, keys=None, sep=": ") -> str: # we know td is lazy stacked and the key is a leaf # so we can get the shape and escape the error temp_td = td - from tensordict import LazyStackedTensorDict, TensorDictBase + from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDictBase, + ) while isinstance( temp_td, LazyStackedTensorDict - ): # we need to grab the het tensor from the inner nesting level + ): # we need to grab the heterogeneous tensor from the inner nesting level temp_td = temp_td.tensordicts[0] tensor = temp_td.get(key) + if is_tensor_collection(tensor): + tensor = td.get(key) + strs.append(_make_repr(key, tensor, td, sep=sep)) + continue if isinstance(tensor, TensorDictBase): substr = _td_fields(tensor) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 697590b14..928b5d317 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7168,6 +7168,43 @@ def test_repr_nested(self, device, dtype): is_shared={is_shared})""" assert repr(nested_td) == expected + def test_repr_nested_lazy(self, device, dtype): + nested_td0 = self.nested_td(device, dtype) + nested_td1 = torch.cat([nested_td0, nested_td0], 1) + nested_td1["my_nested_td", "another"] = nested_td1["my_nested_td", "a"] + lazy_nested_td = TensorDict.lazy_stack([nested_td0, nested_td1], dim=1) + + if device is not None and device.type == "cuda": + is_shared = True + else: + is_shared = False + tensor_class = "Tensor" + tensor_device = device if device else nested_td0[:, 0]["b"].device + if tensor_device.type == "cuda": + is_shared_tensor = True + else: + is_shared_tensor = is_shared + expected = f"""LazyStackedTensorDict( + fields={{ + b: {tensor_class}(shape=torch.Size([4, 2, -1, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor}), + my_nested_td: LazyStackedTensorDict( + fields={{ + a: {tensor_class}(shape=torch.Size([4, 2, -1, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, + exclusive_fields={{ + 1 -> + another: Tensor(shape=torch.Size([4, 6, 2, 1, 5]), device={tensor_device}, dtype={dtype}, is_shared={is_shared_tensor})}}, + batch_size=torch.Size([4, 2, -1, 2, 1]), + device={str(device)}, + is_shared={is_shared}, + stack_dim=1)}}, + exclusive_fields={{ + }}, + batch_size=torch.Size([4, 2, -1, 2, 1]), + device={str(device)}, + is_shared={is_shared}, + stack_dim=1)""" + assert repr(lazy_nested_td) == expected + def test_repr_nested_update(self, device, dtype): nested_td = self.nested_td(device, dtype) nested_td["my_nested_td"].rename_key_("a", "z")