1
1
import torch
2
2
import timm
3
3
import pytest
4
+ import unittest
4
5
5
6
import torch_tensorrt as torchtrt
6
7
import torchvision .models as models
12
13
cosine_similarity ,
13
14
)
14
15
16
+ assertions = unittest .TestCase ()
17
+
15
18
16
19
@pytest .mark .unit
17
20
def test_resnet18 (ir ):
@@ -31,9 +34,9 @@ def test_resnet18(ir):
31
34
32
35
trt_mod = torchtrt .compile (model , ** compile_spec )
33
36
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
34
- assert (
37
+ assertions . assertTrue (
35
38
cos_sim > COSINE_THRESHOLD ,
36
- f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
39
+ msg = f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
37
40
)
38
41
39
42
# Clean up model env
@@ -61,9 +64,9 @@ def test_mobilenet_v2(ir):
61
64
62
65
trt_mod = torchtrt .compile (model , ** compile_spec )
63
66
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
64
- assert (
67
+ assertions . assertTrue (
65
68
cos_sim > COSINE_THRESHOLD ,
66
- f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
69
+ msg = f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
67
70
)
68
71
69
72
# Clean up model env
@@ -91,9 +94,9 @@ def test_efficientnet_b0(ir):
91
94
92
95
trt_mod = torchtrt .compile (model , ** compile_spec )
93
96
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
94
- assert (
97
+ assertions . assertTrue (
95
98
cos_sim > COSINE_THRESHOLD ,
96
- f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
99
+ msg = f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
97
100
)
98
101
99
102
# Clean up model env
@@ -134,9 +137,9 @@ def test_bert_base_uncased(ir):
134
137
for key in model_outputs .keys ():
135
138
out , trt_out = model_outputs [key ], trt_model_outputs [key ]
136
139
cos_sim = cosine_similarity (out , trt_out )
137
- assert (
140
+ assertions . assertTrue (
138
141
cos_sim > COSINE_THRESHOLD ,
139
- f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
142
+ msg = f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
140
143
)
141
144
142
145
# Clean up model env
@@ -164,9 +167,9 @@ def test_resnet18_half(ir):
164
167
165
168
trt_mod = torchtrt .compile (model , ** compile_spec )
166
169
cos_sim = cosine_similarity (model (input ), trt_mod (input ))
167
- assert (
170
+ assertions . assertTrue (
168
171
cos_sim > COSINE_THRESHOLD ,
169
- f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
172
+ msg = f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: { cos_sim } Threshold: { COSINE_THRESHOLD } " ,
170
173
)
171
174
172
175
# Clean up model env
0 commit comments