diff --git a/axlearn/common/test_utils.py b/axlearn/common/test_utils.py index 6139de82..15298507 100644 --- a/axlearn/common/test_utils.py +++ b/axlearn/common/test_utils.py @@ -215,7 +215,7 @@ def assertNestedAllClose(self, a, b, atol=1e-6, rtol=1e-3): self.assertEqual(a_value.shape, b_value.shape, msg=f"{a_name}") assert_allclose(a_value, b_value, atol=atol, rtol=rtol, err_msg=f"{a_name}") else: - self.assertAlmostEqual(a_value, b_value) + self.assertAlmostEqual(a_value, b_value, msg=f"{a_name}") def assertNestedEqual(self, a, b): a_kv = flatten_items(a)