
Commit 711930f

chore: wrapped module runtime api draft
1 parent 7e22f61 commit 711930f

8 files changed: +120 -74 lines
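Summary of the draft: it threads a WrapperTorchTensorRTModule runtime path through the dynamo stack and sketches three ways to reach it — a new enable_wrapper_module compile option, an optional module argument on torch_tensorrt.runtime.enable_cudagraphs(), and direct construction of the wrapper. The wrapper now derives everything it needs from the compiled module itself, so the output_shapes/output_dtypes constructor arguments (and the FakeTensorMode shape inference that fed them) are removed.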

py/torch_tensorrt/dynamo/_compiler.py

+4 -6

@@ -376,6 +376,7 @@ def compile(
     use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
     use_fp32_acc: bool = _defaults.USE_FP32_ACC,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
+    enable_wrapper_module: bool = _defaults.ENABLE_WRAPPER_MODULE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT

@@ -592,6 +593,7 @@ def compile(
         "use_fp32_acc": use_fp32_acc,
         "enable_cross_compile_for_windows": False,
         "enable_weight_streaming": enable_weight_streaming,
+        "enable_wrapper_module": enable_wrapper_module,
     }
 
     settings = CompilationSettings(**compilation_options)

@@ -835,13 +837,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     dryrun_stats_display(dryrun_tracker, settings.dryrun)
 
-    if len(dryrun_tracker.to_run_in_torch) > 0:
+    if settings.enable_wrapper_module:
         # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
-        partitioned_module = WrapperTorchTensorRTModule(
-            partitioned_module,
-            dryrun_tracker.output_shapes,
-            dryrun_tracker.output_dtypes,
-        )
+        partitioned_module = WrapperTorchTensorRTModule(partitioned_module)
 
     return partitioned_module
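For quick reference, a minimal usage sketch of API option 1 (mirroring the new test added in this commit; SampleModel is an illustrative stand-in for any traceable module):

import torch
import torch_tensorrt as torchtrt

class SampleModel(torch.nn.Module):  # stand-in for any traceable module
    def forward(self, x):
        return torch.relu(x)

model = SampleModel().eval().cuda()
inputs = [torch.randn((64, 32), dtype=torch.float32).cuda()]

# enable_wrapper_module=True makes compile() return the partitioned graph
# wrapped in WrapperTorchTensorRTModule, so CUDA graph capture/replay covers
# TRT engines and Torch-executed (graph break) segments as one whole graph.
optimized_model = torchtrt.compile(
    model,
    inputs=inputs,
    ir="dynamo",
    min_block_size=1,
    enable_wrapper_module=True,
)

with torchtrt.runtime.enable_cudagraphs():
    out = optimized_model(*inputs)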

py/torch_tensorrt/dynamo/_defaults.py

+1

@@ -44,6 +44,7 @@
 USE_FP32_ACC = False
 ENABLE_WEIGHT_STREAMING = False
 ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
+ENABLE_WRAPPER_MODULE = False
 
 
 def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

+2

@@ -16,6 +16,7 @@
     ENABLE_CROSS_COMPILE_FOR_WINDOWS,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENABLE_WEIGHT_STREAMING,
+    ENABLE_WRAPPER_MODULE,
     ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,
     HARDWARE_COMPATIBLE,

@@ -125,6 +126,7 @@ class CompilationSettings:
     use_fp32_acc: bool = USE_FP32_ACC
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
+    enable_wrapper_module: bool = ENABLE_WRAPPER_MODULE
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
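Since CompilationSettings is a dataclass, the flag can also be set when building settings programmatically (a sketch; the default comes from _defaults.ENABLE_WRAPPER_MODULE):

from torch_tensorrt.dynamo._settings import CompilationSettings

# Defaults to ENABLE_WRAPPER_MODULE (False); opting in is per compilation,
# like weight streaming.
settings = CompilationSettings(enable_wrapper_module=True)
assert settings.enable_wrapper_module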

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

-2

@@ -250,8 +250,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-        # TODO: calculate output shape under fakeTensorMode
-        # fake_mode = detect_fake_mode(*inputs)
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled

py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py

+1 -51

@@ -7,11 +7,8 @@
 
 import torch
 import torch_tensorrt
-from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.conversion import DYNAMIC_DIM
-from torch_tensorrt.dynamo.utils import input_is_dynamic
 from torch_tensorrt.runtime._utils import _is_switch_required, _select_rt_device
 
 logger = logging.getLogger(__name__)

@@ -21,25 +18,18 @@ class WrapperTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
     Args:
-        original_module: Unmodified FX GraphModule
         compiled_module: Complied fx graphModule that will be wrapped
-        output_shapes: Shapes of output Tensors of the graph
-        output_dtypes: Output data types of the graph
     Returns:
         Output tensor or tensor list
     """
 
     def __init__(
         self,
         compiled_module: torch.nn.Module,
-        output_shapes: List[torch.Size],
-        output_dtypes: List[torch.dtype],
     ):
         super(WrapperTorchTensorRTModule, self).__init__()
         self.compiled_module = compiled_module
         self.inputs = partitioning.construct_submodule_inputs(compiled_module)
-        self.output_shapes = output_shapes
-        self.output_dtypes = output_dtypes
 
         self._input_buffers: List[torch.Tensor] = []
         self._output_buffers: List[torch.Tensor] = []

@@ -49,7 +39,6 @@ def __init__(
         self.prev_cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
-        self.input_is_dynamic = input_is_dynamic(self.inputs)
 
         # Disable cudagrphs in submodules as it will be enabled in wrapper
         for name, rt_mod in self.compiled_module.named_children():

@@ -82,18 +71,9 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
         new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)
 
-        # If the new shape key differs from the existing one, infer new output shape
         if new_shape_key != self.shape_key:
             logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
-
-            if self.input_is_dynamic:
-                with FakeTensorMode(allow_non_fake_inputs=True):
-                    tmp_outputs = self.compiled_module(*inputs)
-                    if not isinstance(tmp_outputs, (list, tuple)):
-                        tmp_outputs = [tmp_outputs]
-                    self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
-            print("self.output_shapes ", self.output_shapes)
             return True
 
         return False

@@ -128,7 +108,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self.cudagraph.reset()
 
         self._input_buffers = [None] * len(self.inputs)
-        self._output_buffers = [None] * len(self.output_shapes)
 
         if not cudagraphs_enabled and self.cudagraph:
             self.cudagraph.reset()

@@ -202,32 +181,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             elif cudagraphs_enabled:
                 self._input_buffers[i].copy_(contiguous_inputs[i])
 
-        with (
-            torch.autograd.profiler.record_function(
-                "WrapperTorchTensorRTModule:ProcessOutputs"
-            )
-            if self.profiling_enabled
-            else nullcontext()
-        ):
-            # create output tensors
-            outputs: List[torch.Tensor] = []
-
-            for o, shape in enumerate(self.output_shapes):
-                if DYNAMIC_DIM in shape:
-                    raise ValueError(
-                        "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
-                    )
-
-                output = torch.empty(
-                    size=shape,
-                    dtype=self.output_dtypes[o],
-                    device=torch.cuda.current_device(),
-                )
-
-                outputs.append(output)
-
-                if need_cudagraphs_record:
-                    self._output_buffers[o] = outputs[o].clone()
         with (
             torch.autograd.profiler.record_function(
                 "WrapperTorchTensorRTModule:TensorRTRuntime"

@@ -277,13 +230,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 output_buffers = self._output_buffers
             else:
                 output_buffers = [self._output_buffers]
-            for idx, o in enumerate(outputs):
-                o.copy_(output_buffers[idx])
-
+            outputs = [output.clone() for output in output_buffers]
             if len(outputs) == 1:
                 return outputs[0]
 
             return outputs
         else:
-
             return outputs
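With output_shapes/output_dtypes gone from the constructor, the wrapper can be applied to any compiled module directly (API option 3 in the test; a sketch where the ReLU module and inputs are illustrative placeholders):

import torch
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
    WrapperTorchTensorRTModule,
)

inputs = [torch.randn((64, 32), dtype=torch.float32).cuda()]
optimized_model = torchtrt.compile(
    torch.nn.ReLU().eval().cuda(),  # stand-in module
    inputs=inputs,
    ir="dynamo",
    min_block_size=1,
)

# The wrapper derives its input specs via partitioning.construct_submodule_inputs()
# and clones its output buffers after replay, so no shape/dtype metadata is passed.
wrapped_module = WrapperTorchTensorRTModule(optimized_model)

with torchtrt.runtime.enable_cudagraphs():
    out = wrapped_module(*inputs)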

py/torch_tensorrt/dynamo/runtime/register_fake_class.py

+1 -10

@@ -3,7 +3,6 @@
 from typing import Any, List
 
 import torch
-from torch._library.fake_class_registry import FakeScriptObject
 from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape
 
 

@@ -27,12 +26,7 @@ def fake_tensorrt_execute_engine(
     modes = ["opt"]
 
     # Get the TRTEngine class and infer output shapes based on input shapes
-    # If fake_trt_engine is not FakeScriptObject, assumes that it is the real object
-    if isinstance(fake_trt_engine, FakeScriptObject):
-        trt_engine = fake_trt_engine.wrapped_obj.engine
-    else:
-        trt_engine = fake_trt_engine
-
+    trt_engine = fake_trt_engine.wrapped_obj.engine
     outputs_mode_dict = defaultdict(list)
     for mode in modes:
         input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]

@@ -131,8 +125,5 @@ def automatic_device_memory_budget_getter(self) -> Any:
     def infer_outputs(self, input_shapes: List[Any]) -> Any:
         pass
 
-    def set_whole_cudagraphs(self) -> Any:
-        pass
-
     def __setstate__(self, serialized_state: List[str]) -> Any:
         pass

py/torch_tensorrt/runtime/_cudagraphs.py

+14 -5

@@ -1,8 +1,11 @@
 import logging
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch_tensorrt
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)
 
 if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
     _PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode()

@@ -37,19 +40,25 @@ class _CudagraphsContextManager(object):
     Used to enable cudagraphs as a context manager
     """
 
-    def __init__(self) -> None:
+    def __init__(self, module_to_wrap: Optional[torch.nn.Module]) -> None:
         global _PY_RT_CUDAGRAPHS
         self.old_mode = _PY_RT_CUDAGRAPHS
+        self.module_to_wrap = module_to_wrap
 
     def __enter__(self) -> "_CudagraphsContextManager":
         # Enable cudagraphs
         set_cudagraphs_mode(True)
-        return self
+        if self.module_to_wrap:
+            return WrapperTorchTensorRTModule(self.module_to_wrap)
+        else:
+            return self
 
     def __exit__(self, *args: Any) -> None:
         # Set cudagraphs back to old mode
         set_cudagraphs_mode(self.old_mode)
 
 
-def enable_cudagraphs() -> _CudagraphsContextManager:
-    return _CudagraphsContextManager()
+def enable_cudagraphs(
+    module_to_wrap: Optional[torch.nn.Module] = None,
+) -> _CudagraphsContextManager:
+    return _CudagraphsContextManager(module_to_wrap)
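API option 2 reuses the existing context manager: passing a module makes __enter__ hand back the wrapped module instead of the context manager itself (a sketch; the ReLU module and inputs are illustrative placeholders):

import torch
import torch_tensorrt as torchtrt

inputs = [torch.randn((64, 32), dtype=torch.float32).cuda()]
optimized_model = torchtrt.compile(
    torch.nn.ReLU().eval().cuda(),  # stand-in module
    inputs=inputs,
    ir="dynamo",
    min_block_size=1,
)

# With module_to_wrap given, __enter__ returns the module wrapped in
# WrapperTorchTensorRTModule, so whole-graph capture happens inside the
# context without a separate compile-time flag.
with torchtrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module:
    out = wrapped_module(*inputs)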

tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py

+97

@@ -2,6 +2,9 @@
 import torch_tensorrt as torchtrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)
 
 INPUT_SIZE = (3, 16, 16)
 TRIALS = 5

@@ -197,6 +200,100 @@ def forward(self, x):
         )
         torch._dynamo.reset()
 
+    @parameterized.expand(
+        [
+            ("python_runtime", True),
+            ("cpp_runtime", False),
+        ]
+    )
+    def test_wrapper_cudagraphs_api(self, _, use_python_runtime):
+        """
+        Draft of the three wrapper-module runtime APIs
+        """
+
+        class SampleModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv1d(64, 6, 3)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                out = 1 + self.conv(x)
+                out = self.relu(out)
+                return out
+
+        model = SampleModel().eval().cuda()
+        input_list = []
+        trt_out_list = []
+        ref_out_list = []
+
+        for _ in range(TRIALS):
+            input = [torch.randn((64, 32), dtype=torch.float32).cuda()]
+            input_list.append(input)
+        fx_graph = torch.fx.symbolic_trace(model)
+
+        # 1. Compiler option: enable_wrapper_module=True
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            inputs=input_list[0],
+            ir="dynamo",
+            min_block_size=1,
+            cache_built_engines=False,
+            reuse_cached_engines=False,
+            torch_executed_ops={"torch.ops.aten.convolution.default"},
+            use_python_runtime=use_python_runtime,
+            enable_wrapper_module=True,
+        )
+
+        with torchtrt.runtime.enable_cudagraphs():
+            for i in range(TRIALS):
+                trt_out_list.append(optimized_model(*input_list[i]))
+                ref_out_list.append(fx_graph(*input_list[i]))
+
+        # Compile again to generate a normal (unwrapped) module
+        optimized_model = torchtrt.compile(
+            fx_graph,
+            inputs=input_list[0],
+            ir="dynamo",
+            min_block_size=1,
+            cache_built_engines=False,
+            reuse_cached_engines=False,
+            torch_executed_ops={"torch.ops.aten.convolution.default"},
+            use_python_runtime=use_python_runtime,
+        )
+        # This is the current cuda graphs runtime api
+        with torchtrt.runtime.enable_cudagraphs():
+            for i in range(TRIALS):
+                trt_out_list.append(optimized_model(*input_list[i]))
+                ref_out_list.append(fx_graph(*input_list[i]))
+
+        # 2. Optional parameter in the existing cuda graphs runtime api
+        # WrapperTorchTensorRTModule can be simplified to have only the cuda graph path
+        with torchtrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module:
+            for i in range(TRIALS):
+                trt_out_list.append(wrapped_module(*input_list[i]))
+                ref_out_list.append(fx_graph(*input_list[i]))
+
+        # 3. Use the Wrapper module directly
+        wrapped_module = WrapperTorchTensorRTModule(optimized_model)
+        with torchtrt.runtime.enable_cudagraphs():
+            for i in range(TRIALS):
+                trt_out_list.append(wrapped_module(*input_list[i]))
+                ref_out_list.append(fx_graph(*input_list[i]))
+
+        for optimized_model_results, torch_model_results in zip(
+            trt_out_list, ref_out_list
+        ):
+            torch.testing.assert_close(
+                torch_model_results,
+                optimized_model_results,
+                rtol=5e-03,
+                atol=5e-03,
+                equal_nan=True,
+                check_dtype=True,
+            )
+        torch._dynamo.reset()
 
 
 if __name__ == "__main__":
     run_tests()
