Skip to content

Commit 56fb02b

Browse files
committed
fix: Refactor assertions in E2E tests for Dynamo
- Add unittest assertion module to streamline error messaging and reporting
1 parent bf4474d commit 56fb02b

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

Diff for: py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import timm
33
import pytest
4+
import unittest
45

56
import torch_tensorrt as torchtrt
67
import torchvision.models as models
@@ -12,6 +13,8 @@
1213
cosine_similarity,
1314
)
1415

16+
assertions = unittest.TestCase()
17+
1518

1619
@pytest.mark.unit
1720
def test_resnet18(ir):
@@ -31,9 +34,9 @@ def test_resnet18(ir):
3134

3235
trt_mod = torchtrt.compile(model, **compile_spec)
3336
cos_sim = cosine_similarity(model(input), trt_mod(input))
34-
assert (
37+
assertions.assertTrue(
3538
cos_sim > COSINE_THRESHOLD,
36-
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
39+
msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
3740
)
3841

3942
# Clean up model env
@@ -61,9 +64,9 @@ def test_mobilenet_v2(ir):
6164

6265
trt_mod = torchtrt.compile(model, **compile_spec)
6366
cos_sim = cosine_similarity(model(input), trt_mod(input))
64-
assert (
67+
assertions.assertTrue(
6568
cos_sim > COSINE_THRESHOLD,
66-
f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
69+
msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
6770
)
6871

6972
# Clean up model env
@@ -91,9 +94,9 @@ def test_efficientnet_b0(ir):
9194

9295
trt_mod = torchtrt.compile(model, **compile_spec)
9396
cos_sim = cosine_similarity(model(input), trt_mod(input))
94-
assert (
97+
assertions.assertTrue(
9598
cos_sim > COSINE_THRESHOLD,
96-
f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
99+
msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
97100
)
98101

99102
# Clean up model env
@@ -134,9 +137,9 @@ def test_bert_base_uncased(ir):
134137
for key in model_outputs.keys():
135138
out, trt_out = model_outputs[key], trt_model_outputs[key]
136139
cos_sim = cosine_similarity(out, trt_out)
137-
assert (
140+
assertions.assertTrue(
138141
cos_sim > COSINE_THRESHOLD,
139-
f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
142+
msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
140143
)
141144

142145
# Clean up model env
@@ -164,9 +167,9 @@ def test_resnet18_half(ir):
164167

165168
trt_mod = torchtrt.compile(model, **compile_spec)
166169
cos_sim = cosine_similarity(model(input), trt_mod(input))
167-
assert (
170+
assertions.assertTrue(
168171
cos_sim > COSINE_THRESHOLD,
169-
f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
172+
msg=f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
170173
)
171174

172175
# Clean up model env

0 commit comments

Comments
 (0)