Skip to content

Commit

Permalink
[Feature] Better logs of key errors in assert_close
Browse files Browse the repository at this point in the history
ghstack-source-id: 46cb41d0da34b17ccc248119c43ddba586d29d80
Pull Request resolved: #1082

(cherry picked from commit 747c593)
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent f24e3d8 commit 1ef1188
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,7 @@ def assert_close(
equal_nan: bool = True,
intersection: bool = False,
msg: str = "",
prefix: NestedKey = (),
) -> bool:
"""Asserts that two tensordicts, `actual` and `expected`, are element-wise equal within a tolerance for all entries.
Expand All @@ -1516,6 +1517,7 @@ def assert_close(
intersection (bool, optional): If True, only the intersection of the two tensordicts will be compared.
Default is ``False``.
msg (str, optional): An optional message to include in the assertion error if the check fails.
prefix (NestedKey, optional): a prefix to add to the key for error messages.
Returns:
bool: True if the tensors are close within the specified tolerances, raise an exception otherwise.
Expand Down Expand Up @@ -1547,6 +1549,7 @@ def assert_close(
msg=msg,
intersection=intersection,
equal_nan=equal_nan,
prefix=prefix,
)
return True

Expand Down Expand Up @@ -1581,22 +1584,31 @@ def assert_close(
msg=msg,
intersection=intersection,
equal_nan=equal_nan,
prefix=prefix + (key,),
)
continue
elif not isinstance(input1, torch.Tensor):
continue
if input1.is_nested:
input1v = input1.values()
input2v = input2.values()
mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum()
input1o = input1.offsets()
input2o = input2.offsets()
mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
else:
mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
try:
if input1.is_nested:
input1v = input1.values()
input2v = input2.values()
mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum()
input1o = input1.offsets()
input2o = input2.offsets()
mse = (
mse
+ (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
)
else:
mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
except Exception as err:
raise RuntimeError(
f"Failed to compare key {prefix + (key,)}. Scroll up for more details."
) from err
mse = mse.div(input1.numel()).sqrt().item()

local_msg = f"key {key} does not match, got mse = {mse:4.4f}"
local_msg = f"key {prefix + (key,)} does not match, got mse = {mse:4.4f}"
new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg
if input1.is_nested:
torch.testing.assert_close(
Expand All @@ -1611,7 +1623,7 @@ def assert_close(
torch.testing.assert_close(
input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
)
local_msg = f"key {key} matches"
local_msg = f"key {prefix + (key,)} matches"
msg = "\t".join([local_msg, msg]) if len(msg) else local_msg

return True
Expand Down

0 comments on commit 1ef1188

Please # to comment.