Skip to content

Commit 88d07a9

Browse files
committed
feat(//py): New API to embed engine in new module
Also adds tests to confirm TRT Python API intercompatiability Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 3ec836e commit 88d07a9

File tree

10 files changed

+129
-18
lines changed

10 files changed

+129
-18
lines changed

core/compiler.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ c10::FunctionSchema GenerateGraphSchema(
4646
void AddEngineToGraph(
4747
torch::jit::script::Module mod,
4848
std::shared_ptr<torch::jit::Graph>& g,
49-
std::string& serialized_engine) {
49+
const std::string& serialized_engine) {
5050
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
5151
// Get required metadata about the engine out
5252
auto num_io = engine_ptr->num_io;
@@ -173,9 +173,9 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
173173
return new_mod;
174174
}
175175

176-
torch::jit::script::Module EmbedEngineInNewModule(std::string& engine) {
176+
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
177177
std::ostringstream engine_id;
178-
engine_id << reinterpret_cast<int*>(&engine);
178+
engine_id << reinterpret_cast<const int*>(&engine);
179179
torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
180180
auto new_g = std::make_shared<torch::jit::Graph>();
181181
AddEngineToGraph(new_mod, new_g, engine);

core/compiler.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
1919

2020
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
2121

22-
torch::jit::script::Module EmbedEngineInNewModule(std::string& engine);
22+
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
2323

2424
void set_device(const int gpu_id);
2525

cpp/api/include/trtorch/trtorch.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -485,14 +485,15 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
485485
* @brief Take a previously created TensorRT engine and embed it in
486486
* in a TorchScript module
487487
*
488-
* @param engine: std::string - Precompiled serialized TensorRT engine
488+
* @param engine: std::string - Pre-built serialized TensorRT engine
489489
*
490-
* Takes a prebuilt serialized TensorRT engine and embeds it in a TorchScript
491-
* graph. Registers the engine as the forward method of the module
490+
* Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
491+
* module. Registers execution of the engine as the forward method of the module
492+
* Forward is defined as: forward(Tensor[]) -> Tensor[]
492493
*
493494
* @return: A new module trageting a TensorRT engine
494495
*/
495-
TRTORCH_API torch::jit::Module EmbedEngineInNewModule(std::string& engine);
496+
TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine);
496497

497498
/**
498499
* @brief Set gpu device id

cpp/api/src/trtorch.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
3131
return core::CompileGraph(module, to_internal_compile_spec(info));
3232
}
3333

34-
torch::jit::Module EmbedEngineInNewModule(std::string& engine) {
34+
torch::jit::Module EmbedEngineInNewModule(const std::string& engine) {
3535
return core::EmbedEngineInNewModule(engine);
3636
}
3737

py/trtorch/_compiler.py

+20
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,26 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
124124
return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
125125

126126

127+
def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
128+
"""Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
129+
130+
Takes a pre-built serialied TensorRT engine (as bytes) and embeds it within a TorchScript module.
131+
Registers the forward method to execute the TensorRT engine with the function signature:
132+
133+
forward(Tensor[]) -> Tensor[]
134+
135+
Module can be save with engine embedded with torch.jit.save and moved / loaded according to TRTorch portability rules
136+
137+
Args:
138+
serialized_engine (bytes): Serialized TensorRT engine from either TRTorch or TensorRT APIs
139+
140+
Returns:
141+
torch.jit.ScriptModule: New TorchScript module with engine embedded
142+
"""
143+
cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
144+
return torch.jit._recursive.wrap_cpp_module(cpp_mod)
145+
146+
127147
def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) -> bool:
128148
"""Checks to see if a method is fully supported by TRTorch
129149

py/trtorch/csrc/trtorch_py.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
119119
return core::CheckMethodOperatorSupport(module, method_name);
120120
}
121121

122+
torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine) {
123+
return core::EmbedEngineInNewModule(engine);
124+
}
125+
122126
std::string get_build_info() {
123127
auto info = core::util::get_build_info();
124128
return info;
@@ -270,6 +274,10 @@ PYBIND11_MODULE(_C, m) {
270274
"check_method_op_support",
271275
&trtorch::pyapi::CheckMethodOperatorSupport,
272276
"Takes a module and a method name and checks if the method graph contains purely convertable operators");
277+
m.def(
278+
"embed_engine_in_new_module",
279+
&trtorch::pyapi::EmbedEngineInNewModule,
280+
"Takes a serialized TensorRT engine and wraps it in the forward method of a new TorchScript module");
273281
m.def("get_build_info", &get_build_info, "Returns build info about the compiler as a string");
274282

275283
m.def("_get_logging_prefix", &logging::get_logging_prefix, "Get the current prefix for the logging output");

tests/modules/test_modules_as_engines.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ TEST_P(ModuleTests, ModuleAsEngineIsClose) {
1616
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-5));
1717
}
1818

19-
TEST_P(ModuleTests, ModuleToModuleIsClose) {
19+
TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
2020
std::vector<at::Tensor> inputs;
2121
std::vector<torch::jit::IValue> inputs_ivalues;
2222
for (auto in_shape : input_shapes) {

tests/py/BUILD

+14-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ py_test(
3030
srcs = [
3131
"test_ptq_dataloader_calibrator.py",
3232
"model_test_case.py"
33-
]
33+
],
3434
deps = [
3535
requirement("torchvision")
3636
]
@@ -43,7 +43,7 @@ py_test(
4343
srcs = [
4444
"test_ptq_trt_calibrator.py",
4545
"model_test_case.py"
46-
]
46+
],
4747
deps = [
4848
requirement("torchvision")
4949
]
@@ -56,8 +56,6 @@ py_test(
5656
"test_multi_gpu.py",
5757
"model_test_case.py"
5858
],
59-
"//conditions:default" : []
60-
}),
6159
deps = [
6260
requirement("torchvision")
6361
]
@@ -74,12 +72,23 @@ py_test(
7472
]
7573
)
7674

75+
py_test(
76+
name = "test_trt_intercompatability",
77+
srcs = [
78+
"test_trt_intercompatability.py",
79+
"model_test_case.py"
80+
],
81+
deps = [
82+
requirement("torchvision")
83+
]
84+
)
85+
7786
py_test(
7887
name = "test_ptq_to_backend",
7988
srcs = [
8089
"test_ptq_to_backend.py",
8190
"model_test_case.py"
82-
]
91+
],
8392
deps = [
8493
requirement("torchvision")
8594
]

tests/py/test_api.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,27 @@ def test_compile_script(self):
4545
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
4646
self.assertTrue(same < 2e-3)
4747

48+
class TestPTtoTRTtoPT(ModelTestCase):
49+
def setUp(self):
50+
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
51+
self.ts_model = torch.jit.script(self.model)
52+
53+
def test_pt_to_trt_to_pt(self):
54+
compile_spec = {
55+
"input_shapes": [self.input.shape],
56+
"device": {
57+
"device_type": trtorch.DeviceType.GPU,
58+
"gpu_id": 0,
59+
"dla_core": 0,
60+
"allow_gpu_fallback": False,
61+
"disable_tf32": False
62+
}
63+
}
64+
65+
trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
66+
trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
67+
same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
68+
self.assertTrue(same < 2e-3)
4869

4970
class TestCheckMethodOpSupport(unittest.TestCase):
5071

@@ -59,13 +80,13 @@ def test_check_support(self):
5980
class TestLoggingAPIs(unittest.TestCase):
6081

6182
def test_logging_prefix(self):
62-
new_prefix = "TEST"
83+
new_prefix = "Python API Test: "
6384
trtorch.logging.set_logging_prefix(new_prefix)
6485
logging_prefix = trtorch.logging.get_logging_prefix()
6586
self.assertEqual(new_prefix, logging_prefix)
6687

6788
def test_reportable_log_level(self):
68-
new_level = trtorch.logging.Level.Warning
89+
new_level = trtorch.logging.Level.Error
6990
trtorch.logging.set_reportable_log_level(new_level)
7091
level = trtorch.logging.get_reportable_log_level()
7192
self.assertEqual(new_level, level)
@@ -78,10 +99,11 @@ def test_is_colored_output_on(self):
7899

79100
def test_suite():
80101
suite = unittest.TestSuite()
102+
suite.addTest(unittest.makeSuite(TestLoggingAPIs))
81103
suite.addTest(TestCompile.parametrize(TestCompile, model=models.resnet18(pretrained=True)))
82104
suite.addTest(TestCompile.parametrize(TestCompile, model=models.mobilenet_v2(pretrained=True)))
105+
suite.addTest(TestPTtoTRTtoPT.parametrize(TestPTtoTRTtoPT, model=models.mobilenet_v2(pretrained=True)))
83106
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
84-
suite.addTest(unittest.makeSuite(TestLoggingAPIs))
85107

86108
return suite
87109

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import unittest
2+
import trtorch
3+
import torch
4+
import torchvision.models as models
5+
import tensorrt as trt
6+
7+
from model_test_case import ModelTestCase
8+
9+
10+
class TestPyTorchToTRTEngine(ModelTestCase):
11+
def setUp(self):
12+
self.input = torch.randn((1, 3, 224, 224)).to("cuda:0")
13+
self.ts_model = torch.jit.script(self.model)
14+
15+
def test_pt_to_trt(self):
16+
compile_spec = {
17+
"input_shapes": [self.input.shape],
18+
"device": {
19+
"device_type": trtorch.DeviceType.GPU,
20+
"gpu_id": 0,
21+
"dla_core": 0,
22+
"allow_gpu_fallback": False,
23+
"disable_tf32": False
24+
}
25+
}
26+
27+
trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
28+
29+
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
30+
with trt.Runtime(TRT_LOGGER) as rt:
31+
engine = rt.deserialize_cuda_engine(trt_engine)
32+
with engine.create_execution_context() as ctx:
33+
out = torch.empty(size=tuple(engine.get_binding_shape(1))).to("cuda:0")
34+
bindings = [self.input.contiguous().data_ptr(), out.contiguous().data_ptr()]
35+
ctx.execute_async(batch_size=1, bindings=bindings, stream_handle=torch.cuda.current_stream(device='cuda:0').cuda_stream)
36+
same = (out - self.ts_model(self.input)).abs().max()
37+
self.assertTrue(same < 2e-3)
38+
39+
def test_suite():
40+
suite = unittest.TestSuite()
41+
suite.addTest(TestPyTorchToTRTEngine.parametrize(TestPyTorchToTRTEngine, model=models.resnet18(pretrained=True)))
42+
43+
return suite
44+
45+
46+
suite = test_suite()
47+
48+
runner = unittest.TextTestRunner()
49+
result = runner.run(suite)
50+
51+
exit(int(not result.wasSuccessful()))

0 commit comments

Comments
 (0)