@@ -45,6 +45,27 @@ def test_compile_script(self):
45
45
same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
46
46
self .assertTrue (same < 2e-3 )
47
47
48
+ class TestPTtoTRTtoPT (ModelTestCase ):
49
+ def setUp (self ):
50
+ self .input = torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )
51
+ self .ts_model = torch .jit .script (self .model )
52
+
53
+ def test_pt_to_trt_to_pt (self ):
54
+ compile_spec = {
55
+ "input_shapes" : [self .input .shape ],
56
+ "device" : {
57
+ "device_type" : trtorch .DeviceType .GPU ,
58
+ "gpu_id" : 0 ,
59
+ "dla_core" : 0 ,
60
+ "allow_gpu_fallback" : False ,
61
+ "disable_tf32" : False
62
+ }
63
+ }
64
+
65
+ trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , compile_spec )
66
+ trt_mod = trtorch .embed_engine_in_new_module (trt_engine )
67
+ same = (trt_mod (self .input ) - self .ts_model (self .input )).abs ().max ()
68
+ self .assertTrue (same < 2e-3 )
48
69
49
70
class TestCheckMethodOpSupport (unittest .TestCase ):
50
71
@@ -59,13 +80,13 @@ def test_check_support(self):
59
80
class TestLoggingAPIs (unittest .TestCase ):
60
81
61
82
def test_logging_prefix (self ):
62
- new_prefix = "TEST "
83
+ new_prefix = "Python API Test: "
63
84
trtorch .logging .set_logging_prefix (new_prefix )
64
85
logging_prefix = trtorch .logging .get_logging_prefix ()
65
86
self .assertEqual (new_prefix , logging_prefix )
66
87
67
88
def test_reportable_log_level (self ):
68
- new_level = trtorch .logging .Level .Warning
89
+ new_level = trtorch .logging .Level .Error
69
90
trtorch .logging .set_reportable_log_level (new_level )
70
91
level = trtorch .logging .get_reportable_log_level ()
71
92
self .assertEqual (new_level , level )
@@ -78,10 +99,11 @@ def test_is_colored_output_on(self):
78
99
79
100
def test_suite ():
80
101
suite = unittest .TestSuite ()
102
+ suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
81
103
suite .addTest (TestCompile .parametrize (TestCompile , model = models .resnet18 (pretrained = True )))
82
104
suite .addTest (TestCompile .parametrize (TestCompile , model = models .mobilenet_v2 (pretrained = True )))
105
+ suite .addTest (TestPTtoTRTtoPT .parametrize (TestPTtoTRTtoPT , model = models .mobilenet_v2 (pretrained = True )))
83
106
suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
84
- suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
85
107
86
108
return suite
87
109
0 commit comments