diff --git a/tensordict/utils.py b/tensordict/utils.py index ea8bba786..e43e4e8fd 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1499,6 +1499,7 @@ def assert_close( equal_nan: bool = True, intersection: bool = False, msg: str = "", + prefix: NestedKey = (), ) -> bool: """Asserts that two tensordicts, `actual` and `expected`, are element-wise equal within a tolerance for all entries. @@ -1516,6 +1517,7 @@ def assert_close( intersection (bool, optional): If True, only the intersection of the two tensordicts will be compared. Default is ``False``. msg (str, optional): An optional message to include in the assertion error if the check fails. + prefix (NestedKey, optional): a prefix to add to the key for error messages. Returns: bool: True if the tensors are close within the specified tolerances, raise an exception otherwise. @@ -1547,6 +1549,7 @@ def assert_close( msg=msg, intersection=intersection, equal_nan=equal_nan, + prefix=prefix, ) return True @@ -1581,22 +1584,31 @@ def assert_close( msg=msg, intersection=intersection, equal_nan=equal_nan, + prefix=prefix + (key,), ) continue elif not isinstance(input1, torch.Tensor): continue - if input1.is_nested: - input1v = input1.values() - input2v = input2.values() - mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum() - input1o = input1.offsets() - input2o = input2.offsets() - mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum() - else: - mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() + try: + if input1.is_nested: + input1v = input1.values() + input2v = input2.values() + mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum() + input1o = input1.offsets() + input2o = input2.offsets() + mse = ( + mse + + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum() + ) + else: + mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() + except Exception as err: + raise RuntimeError( + f"Failed to compare key {prefix + (key,)}. Scroll up for more details." + ) from err mse = mse.div(input1.numel()).sqrt().item() - local_msg = f"key {key} does not match, got mse = {mse:4.4f}" + local_msg = f"key {prefix + (key,)} does not match, got mse = {mse:4.4f}" new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg if input1.is_nested: torch.testing.assert_close( @@ -1611,7 +1623,7 @@ def assert_close( torch.testing.assert_close( input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg ) - local_msg = f"key {key} matches" + local_msg = f"key {prefix + (key,)} matches" msg = "\t".join([local_msg, msg]) if len(msg) else local_msg return True