@@ -58,7 +58,8 @@ def train_model(session, use_host_env=False):
58
58
59
59
session .run_always ('python' ,
60
60
'export_ckpt.py' ,
61
- 'vgg16_ckpts/ckpt_epoch25.pth' )
61
+ 'vgg16_ckpts/ckpt_epoch25.pth' ,
62
+ env = {'PYTHONPATH' : PYT_PATH })
62
63
else :
63
64
session .run_always ('python' ,
64
65
'main.py' ,
@@ -146,13 +147,27 @@ def run_accuracy_tests(session, use_host_env=False):
146
147
else :
147
148
session .run_always ("python" , test )
148
149
150
+ def copy_model (session ):
151
+ model_files = [ 'trained_vgg16.jit.pt' ,
152
+ 'trained_vgg16_qat.jit.pt' ]
153
+
154
+ for file_name in model_files :
155
+ src_file = os .path .join (TOP_DIR , str ('examples/int8/training/vgg16/' ) + file_name )
156
+ if os .path .exists (src_file ):
157
+ session .run_always ('cp' ,
158
+ '-rpf' ,
159
+ os .path .join (TOP_DIR , src_file ),
160
+ os .path .join (TOP_DIR , str ('tests/py/' ) + file_name ),
161
+ external = True )
162
+
149
163
def run_int8_accuracy_tests (session , use_host_env = False ):
150
164
print ("Running accuracy tests" )
165
+ copy_model (session )
151
166
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
152
167
tests = [
153
- "test_ptq_dataloader .py" ,
168
+ "test_ptq_dataloader_calibrator .py" ,
154
169
"test_ptq_to_backend.py" ,
155
- "test_qat_trt_accuracy" ,
170
+ "test_qat_trt_accuracy.py " ,
156
171
]
157
172
for test in tests :
158
173
if use_host_env :
@@ -162,9 +177,10 @@ def run_int8_accuracy_tests(session, use_host_env=False):
162
177
163
178
def run_trt_compatibility_tests (session , use_host_env = False ):
164
179
print ("Running TensorRT compatibility tests" )
180
+ copy_model (session )
165
181
session .chdir (os .path .join (TOP_DIR , 'tests/py' ))
166
182
tests = [
167
- "test_trt_intercompatibilty .py" ,
183
+ "test_trt_intercompatability .py" ,
168
184
"test_ptq_trt_calibrator.py" ,
169
185
]
170
186
for test in tests :
@@ -218,7 +234,7 @@ def run_l1_accuracy_tests(session, use_host_env=False):
218
234
install_deps (session )
219
235
install_torch_trt (session )
220
236
download_models (session , use_host_env )
221
- download_datasets (session , use_host_env )
237
+ download_datasets (session )
222
238
train_model (session , use_host_env )
223
239
run_accuracy_tests (session , use_host_env )
224
240
cleanup (session )
@@ -228,7 +244,7 @@ def run_l1_int8_accuracy_tests(session, use_host_env=False):
228
244
install_deps (session )
229
245
install_torch_trt (session )
230
246
download_models (session , use_host_env )
231
- download_datasets (session , use_host_env )
247
+ download_datasets (session )
232
248
train_model (session , use_host_env )
233
249
finetune_model (session , use_host_env )
234
250
run_int8_accuracy_tests (session , use_host_env )
@@ -239,6 +255,8 @@ def run_l2_trt_compatibility_tests(session, use_host_env=False):
239
255
install_deps (session )
240
256
install_torch_trt (session )
241
257
download_models (session , use_host_env )
258
+ download_datasets (session )
259
+ train_model (session , use_host_env )
242
260
run_trt_compatibility_tests (session , use_host_env )
243
261
cleanup (session )
244
262
0 commit comments