Skip to content

Commit 11886fe

Browse files
committed
chore: Rename to CudaGraphsTorchTensorRTModule class
1 parent 5edf79a commit 11886fe

File tree

6 files changed

+13
-329
lines changed

6 files changed

+13
-329
lines changed

Diff for: examples/dynamo/cudagraphs_wrapper_example.py

-111
This file was deleted.

Diff for: py/torch_tensorrt/_compile.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from torch_tensorrt._features import ENABLED_FEATURES
1313
from torch_tensorrt._Input import Input
1414
from torch_tensorrt.dynamo import _defaults
15-
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
16-
WrapperTorchTensorRTModule,
15+
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
16+
CudaGraphsTorchTensorRTModule,
1717
)
1818
from torch_tensorrt.fx import InputTensorSpec
1919
from torch_tensorrt.fx.lower import compile as fx_compile
@@ -589,15 +589,15 @@ def save(
589589
Save the model to disk in the specified output format.
590590
591591
Arguments:
592-
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | WrapperTorchTensorRTModule)): Compiled Torch-TensorRT module
592+
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
593593
inputs (torch.Tensor): Torch input tensors
594594
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
595595
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
596596
output_format (str): Format to save the model. Options include exported_program | torchscript.
597597
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
598598
This flag is experimental for now.
599599
"""
600-
if isinstance(module, WrapperTorchTensorRTModule):
600+
if isinstance(module, CudaGraphsTorchTensorRTModule):
601601
module = module.compiled_module
602602
module_type = _parse_module_type(module)
603603
accepted_formats = {"exported_program", "torchscript"}

Diff for: py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py renamed to py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
logger = logging.getLogger(__name__)
1212

1313

14-
class WrapperTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
14+
class CudaGraphsTorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
1515
"""This Wrapper runtime module is to record/replay whole cuda graph in sub modules
1616
1717
Args:
@@ -24,7 +24,7 @@ def __init__(
2424
self,
2525
compiled_module: torch.nn.Module,
2626
):
27-
super(WrapperTorchTensorRTModule, self).__init__()
27+
super(CudaGraphsTorchTensorRTModule, self).__init__()
2828
self.compiled_module = compiled_module
2929
self.inputs = partitioning.construct_submodule_inputs(compiled_module)
3030

Diff for: py/torch_tensorrt/runtime/_cudagraphs.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import torch
55
import torch_tensorrt
6-
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
7-
WrapperTorchTensorRTModule,
6+
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
7+
CudaGraphsTorchTensorRTModule,
88
)
99

1010

@@ -90,7 +90,7 @@ def __enter__(self) -> torch.nn.Module:
9090
logger.debug(
9191
f"{num_torch_module} torch modules are in subgraphs. Using wrapper module for cuda graphs"
9292
)
93-
return WrapperTorchTensorRTModule(self.compiled_module)
93+
return CudaGraphsTorchTensorRTModule(self.compiled_module)
9494
else:
9595
if num_trt_module > 0:
9696
logger.debug(

Diff for: py/torch_tensorrt/runtime/_weight_streaming.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import torch
55
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
6-
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
7-
WrapperTorchTensorRTModule,
6+
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
7+
CudaGraphsTorchTensorRTModule,
88
)
99

1010
logger = logging.getLogger(__name__)
@@ -16,12 +16,12 @@ class _WeightStreamingContextManager(object):
1616
"""
1717

1818
def __init__(
19-
self, module: torch.fx.GraphModule | WrapperTorchTensorRTModule
19+
self, module: torch.fx.GraphModule | CudaGraphsTorchTensorRTModule
2020
) -> None:
2121
rt_mods = []
2222
self.current_device_budget = 0
2323

24-
if isinstance(module, WrapperTorchTensorRTModule):
24+
if isinstance(module, CudaGraphsTorchTensorRTModule):
2525
module = module.compiled_module
2626
for name, rt_mod in module.named_children():
2727
if "_run_on_acc" in name and isinstance(

0 commit comments

Comments (0)