Skip to content

Commit ec2232f

Browse files
committed
fix: Fixed failures for host deps sessions
Signed-off-by: Anurag Dixit <a.dixit91@gmail.com>
1 parent 8580423 commit ec2232f

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

noxfile.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def train_model(session, use_host_env=False):
5858

5959
session.run_always('python',
6060
'export_ckpt.py',
61-
'vgg16_ckpts/ckpt_epoch25.pth')
61+
'vgg16_ckpts/ckpt_epoch25.pth',
62+
env={'PYTHONPATH': PYT_PATH})
6263
else:
6364
session.run_always('python',
6465
'main.py',
@@ -146,13 +147,27 @@ def run_accuracy_tests(session, use_host_env=False):
146147
else:
147148
session.run_always("python", test)
148149

150+
def copy_model(session):
151+
model_files = [ 'trained_vgg16.jit.pt',
152+
'trained_vgg16_qat.jit.pt']
153+
154+
for file_name in model_files:
155+
src_file = os.path.join(TOP_DIR, str('examples/int8/training/vgg16/') + file_name)
156+
if os.path.exists(src_file):
157+
session.run_always('cp',
158+
'-rpf',
159+
os.path.join(TOP_DIR, src_file),
160+
os.path.join(TOP_DIR, str('tests/py/') + file_name),
161+
external=True)
162+
149163
def run_int8_accuracy_tests(session, use_host_env=False):
150164
print("Running accuracy tests")
165+
copy_model(session)
151166
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
152167
tests = [
153-
"test_ptq_dataloader.py",
168+
"test_ptq_dataloader_calibrator.py",
154169
"test_ptq_to_backend.py",
155-
"test_qat_trt_accuracy",
170+
"test_qat_trt_accuracy.py",
156171
]
157172
for test in tests:
158173
if use_host_env:
@@ -162,9 +177,10 @@ def run_int8_accuracy_tests(session, use_host_env=False):
162177

163178
def run_trt_compatibility_tests(session, use_host_env=False):
164179
print("Running TensorRT compatibility tests")
180+
copy_model(session)
165181
session.chdir(os.path.join(TOP_DIR, 'tests/py'))
166182
tests = [
167-
"test_trt_intercompatibilty.py",
183+
"test_trt_intercompatability.py",
168184
"test_ptq_trt_calibrator.py",
169185
]
170186
for test in tests:
@@ -218,7 +234,7 @@ def run_l1_accuracy_tests(session, use_host_env=False):
218234
install_deps(session)
219235
install_torch_trt(session)
220236
download_models(session, use_host_env)
221-
download_datasets(session, use_host_env)
237+
download_datasets(session)
222238
train_model(session, use_host_env)
223239
run_accuracy_tests(session, use_host_env)
224240
cleanup(session)
@@ -228,7 +244,7 @@ def run_l1_int8_accuracy_tests(session, use_host_env=False):
228244
install_deps(session)
229245
install_torch_trt(session)
230246
download_models(session, use_host_env)
231-
download_datasets(session, use_host_env)
247+
download_datasets(session)
232248
train_model(session, use_host_env)
233249
finetune_model(session, use_host_env)
234250
run_int8_accuracy_tests(session, use_host_env)
@@ -239,6 +255,8 @@ def run_l2_trt_compatibility_tests(session, use_host_env=False):
239255
install_deps(session)
240256
install_torch_trt(session)
241257
download_models(session, use_host_env)
258+
download_datasets(session)
259+
train_model(session, use_host_env)
242260
run_trt_compatibility_tests(session, use_host_env)
243261
cleanup(session)
244262

0 commit comments

Comments
 (0)