@@ -31,9 +31,25 @@ def test_compile_traced(self):
31
31
32
32
def test_compile_script (self ):
33
33
trt_mod = torchtrt .ts .compile (self .scripted_model ,
34
- inputs = [self .input ],
35
- device = torchtrt .Device (gpu_id = 0 ),
36
- enabled_precisions = {torch .float })
34
+ inputs = [self .input ],
35
+ device = torchtrt .Device (gpu_id = 0 ),
36
+ enabled_precisions = {torch .float })
37
+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
38
+ self .assertTrue (same < 2e-2 )
39
+
40
+ def test_compile_global (self ):
41
+ trt_mod = torchtrt .compile (self .scripted_model ,
42
+ inputs = [self .input ],
43
+ device = torchtrt .Device (gpu_id = 0 ),
44
+ enabled_precisions = {torch .float })
45
+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
46
+ self .assertTrue (same < 2e-2 )
47
+
48
+ def test_compile_global_nn_mod (self ):
49
+ trt_mod = torchtrt .compile (self .model ,
50
+ inputs = [self .input ],
51
+ device = torchtrt .Device (gpu_id = 0 ),
52
+ enabled_precisions = {torch .float })
37
53
same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
38
54
self .assertTrue (same < 2e-2 )
39
55
@@ -207,16 +223,16 @@ def setUp(self):
207
223
def test_input_use_default_fp32 (self ):
208
224
ts_model = torch .jit .script (self .model )
209
225
trt_mod = torchtrt .ts .compile (ts_model ,
210
- inputs = [torchtrt .Input (self .input .shape )],
211
- enabled_precisions = {torch .float , torch .half })
226
+ inputs = [torchtrt .Input (self .input .shape )],
227
+ enabled_precisions = {torch .float , torch .half })
212
228
trt_mod (self .input )
213
229
214
230
def test_input_respect_user_setting_fp32_weights_fp16_in (self ):
215
231
ts_model = torch .jit .script (self .model )
216
232
trt_mod = torchtrt .ts .compile (ts_model ,
217
- inputs = [self .input .half ()],
218
- require_full_compilation = True ,
219
- enabled_precisions = {torch .float , torch .half })
233
+ inputs = [self .input .half ()],
234
+ require_full_compilation = True ,
235
+ enabled_precisions = {torch .float , torch .half })
220
236
trt_mod (self .input .half ())
221
237
222
238
def test_input_respect_user_setting_fp32_weights_fp16_in_non_constructor (self ):
@@ -225,9 +241,9 @@ def test_input_respect_user_setting_fp32_weights_fp16_in_non_constructor(self):
225
241
input_spec .dtype = torch .half
226
242
227
243
trt_mod = torchtrt .ts .compile (ts_model ,
228
- inputs = [input_spec ],
229
- require_full_compilation = True ,
230
- enabled_precisions = {torch .float , torch .half })
244
+ inputs = [input_spec ],
245
+ require_full_compilation = True ,
246
+ enabled_precisions = {torch .float , torch .half })
231
247
trt_mod (self .input .half ())
232
248
233
249
@@ -241,8 +257,8 @@ def test_input_use_default_fp16(self):
241
257
half_mod .half ()
242
258
243
259
trt_mod = torchtrt .ts .compile (half_mod ,
244
- inputs = [torchtrt .Input (self .input .shape )],
245
- enabled_precisions = {torch .float , torch .half })
260
+ inputs = [torchtrt .Input (self .input .shape )],
261
+ enabled_precisions = {torch .float , torch .half })
246
262
trt_mod (self .input .half ())
247
263
248
264
def test_input_use_default_fp16_without_fp16_enabled (self ):
@@ -257,9 +273,9 @@ def test_input_respect_user_setting_fp16_weights_fp32_in(self):
257
273
half_mod .half ()
258
274
259
275
trt_mod = torchtrt .ts .compile (half_mod ,
260
- inputs = [self .input ],
261
- require_full_compilation = True ,
262
- enabled_precisions = {torch .float , torch .half })
276
+ inputs = [self .input ],
277
+ require_full_compilation = True ,
278
+ enabled_precisions = {torch .float , torch .half })
263
279
trt_mod (self .input )
264
280
265
281
def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor (self ):
@@ -270,9 +286,9 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
270
286
input_spec .dtype = torch .float
271
287
272
288
trt_mod = torchtrt .ts .compile (half_mod ,
273
- inputs = [input_spec ],
274
- require_full_compilation = True ,
275
- enabled_precisions = {torch .float , torch .half })
289
+ inputs = [input_spec ],
290
+ require_full_compilation = True ,
291
+ enabled_precisions = {torch .float , torch .half })
276
292
trt_mod (self .input )
277
293
278
294
@@ -352,14 +368,15 @@ def test_from_torch(self):
352
368
self .assertEqual (device .device_type , torchtrt .DeviceType .GPU )
353
369
self .assertEqual (device .gpu_id , 0 )
354
370
371
+
355
372
class TestInput (unittest .TestCase ):
356
373
357
374
def _verify_correctness (self , struct : torchtrt .Input , target : Dict ) -> bool :
358
375
internal = struct ._to_internal ()
359
376
360
- list_eq = lambda al , bl : all ([a == b for (a , b ) in zip (al , bl )])
377
+ list_eq = lambda al , bl : all ([a == b for (a , b ) in zip (al , bl )])
361
378
362
- eq = lambda a , b : a == b
379
+ eq = lambda a , b : a == b
363
380
364
381
def field_is_correct (field , equal_fn , a1 , a2 ):
365
382
equal = equal_fn (a1 , a2 )
@@ -371,12 +388,12 @@ def field_is_correct(field, equal_fn, a1, a2):
371
388
opt_ = field_is_correct ("opt" , list_eq , internal .opt , target ["opt" ])
372
389
max_ = field_is_correct ("max" , list_eq , internal .max , target ["max" ])
373
390
is_dynamic_ = field_is_correct ("is_dynamic" , eq , internal .input_is_dynamic , target ["input_is_dynamic" ])
374
- explicit_set_dtype_ = field_is_correct ("explicit_dtype" , eq , internal ._explicit_set_dtype , target ["explicit_set_dtype" ])
391
+ explicit_set_dtype_ = field_is_correct ("explicit_dtype" , eq , internal ._explicit_set_dtype ,
392
+ target ["explicit_set_dtype" ])
375
393
dtype_ = field_is_correct ("dtype" , eq , int (internal .dtype ), int (target ["dtype" ]))
376
394
format_ = field_is_correct ("format" , eq , int (internal .format ), int (target ["format" ]))
377
395
378
- return all ([min_ ,opt_ ,max_ ,is_dynamic_ ,explicit_set_dtype_ ,dtype_ ,format_ ])
379
-
396
+ return all ([min_ , opt_ , max_ , is_dynamic_ , explicit_set_dtype_ , dtype_ , format_ ])
380
397
381
398
def test_infer_from_example_tensor (self ):
382
399
shape = [1 , 3 , 255 , 255 ]
@@ -394,7 +411,6 @@ def test_infer_from_example_tensor(self):
394
411
i = torchtrt .Input ._from_tensor (example_tensor )
395
412
self .assertTrue (self ._verify_correctness (i , target ))
396
413
397
-
398
414
def test_static_shape (self ):
399
415
shape = [1 , 3 , 255 , 255 ]
400
416
target = {
@@ -482,9 +498,12 @@ def test_dynamic_shape(self):
482
498
self .assertTrue (self ._verify_correctness (i , target ))
483
499
484
500
tensor_shape = lambda shape : torch .randn (shape ).shape
485
- i = torchtrt .Input (min_shape = tensor_shape (min_shape ), opt_shape = tensor_shape (opt_shape ), max_shape = tensor_shape (max_shape ))
501
+ i = torchtrt .Input (min_shape = tensor_shape (min_shape ),
502
+ opt_shape = tensor_shape (opt_shape ),
503
+ max_shape = tensor_shape (max_shape ))
486
504
self .assertTrue (self ._verify_correctness (i , target ))
487
505
506
+
488
507
def test_suite ():
489
508
suite = unittest .TestSuite ()
490
509
suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
0 commit comments