From 1ef1188d0c1f81827dab8d53d1ccb9a0c1996f05 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 8 Nov 2024 15:11:23 +0000 Subject: [PATCH] [Feature] Better logs of key errors in assert_close ghstack-source-id: 46cb41d0da34b17ccc248119c43ddba586d29d80 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1082 (cherry picked from commit 747c593b29177fde2f383212b33607b97ab6b2ae) --- tensordict/utils.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index ea8bba786..e43e4e8fd 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1499,6 +1499,7 @@ def assert_close( equal_nan: bool = True, intersection: bool = False, msg: str = "", + prefix: NestedKey = (), ) -> bool: """Asserts that two tensordicts, `actual` and `expected`, are element-wise equal within a tolerance for all entries. @@ -1516,6 +1517,7 @@ def assert_close( intersection (bool, optional): If True, only the intersection of the two tensordicts will be compared. Default is ``False``. msg (str, optional): An optional message to include in the assertion error if the check fails. + prefix (NestedKey, optional): a prefix to add to the key for error messages. Returns: bool: True if the tensors are close within the specified tolerances, raise an exception otherwise. @@ -1547,6 +1549,7 @@ def assert_close( msg=msg, intersection=intersection, equal_nan=equal_nan, + prefix=prefix, ) return True @@ -1581,22 +1584,31 @@ def assert_close( msg=msg, intersection=intersection, equal_nan=equal_nan, + prefix=prefix + (key,), ) continue elif not isinstance(input1, torch.Tensor): continue - if input1.is_nested: - input1v = input1.values() - input2v = input2.values() - mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum() - input1o = input1.offsets() - input2o = input2.offsets() - mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum() - else: - mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() + try: + if input1.is_nested: + input1v = input1.values() + input2v = input2.values() + mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum() + input1o = input1.offsets() + input2o = input2.offsets() + mse = ( + mse + + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum() + ) + else: + mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() + except Exception as err: + raise RuntimeError( + f"Failed to compare key {prefix + (key,)}. Scroll up for more details." + ) from err mse = mse.div(input1.numel()).sqrt().item() - local_msg = f"key {key} does not match, got mse = {mse:4.4f}" + local_msg = f"key {prefix + (key,)} does not match, got mse = {mse:4.4f}" new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg if input1.is_nested: torch.testing.assert_close( @@ -1611,7 +1623,7 @@ def assert_close( torch.testing.assert_close( input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg ) - local_msg = f"key {key} matches" + local_msg = f"key {prefix + (key,)} matches" msg = "\t".join([local_msg, msg]) if len(msg) else local_msg return True