
Commit ee05b59

feat: Prototype Module-Acceleration in Dynamo
- Add support for excluding entire Torch modules from tracing in Dynamo using Torch custom operators
- Develop new dataclass to store required replacement functions and operators in a streamlined way
- Add new registry to store mapping between replacement operators and their corresponding dataclass
- Add documentation for easy additions of new module-level exclusion operators (see the usage sketch below)
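A minimal usage sketch of the feature described above (not part of this commit): it assumes the compile entry point in py/torch_tensorrt/dynamo/backend/__init__.py accepts a module and example inputs alongside the debug flag visible in the diff below, and that an nn.MaxPool1d submodule is routed through torch.ops.tensorrt.maxpool1d instead of being traced through.

import torch
from torch_tensorrt.dynamo.backend import compile as dynamo_compile

class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # MaxPool1d is the module type registered for replacement in this commit
        self.pool = torch.nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        return self.pool(x)

model = Pool().eval().cuda()
inputs = [torch.randn(2, 4, 8).cuda()]

# Hypothetical invocation of the experimental Dynamo backend: only the debug
# argument is visible in this diff; the remaining arguments may differ
trt_model = dynamo_compile(model, inputs, debug=True)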
1 parent e109049 commit ee05b59

File tree

4 files changed: +180 -1 lines changed


.circleci/config.yml (+1 -1)

@@ -258,7 +258,7 @@ commands:
       name: Set up python environment
       command: |
         pip3 install --upgrade pip
-        pip3 install wheel setuptools
+        pip3 install wheel setuptools pyyaml
         pip3 install nvidia-pyindex
         pip3 install tabulate
         pip3 install tensorrt==<< parameters.trt-version-long >> nvidia-cudnn-cu11==<< parameters.cudnn-version-long >>

py/torch_tensorrt/dynamo/backend/__init__.py (+3)

@@ -50,6 +50,9 @@ def compile(
     if debug:
         logger.setLevel(logging.DEBUG)

+    if debug:
+        logger.setLevel(logging.DEBUG)
+
     logger.warn(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "

py/torch_tensorrt/dynamo/backend/backends.py (+12)

@@ -8,6 +8,9 @@
 from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
 )
+from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
+    pre_aot_module_replacement,
+)
 from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
     get_submod_inputs,
@@ -46,6 +49,13 @@ def aot_torch_tensorrt_aten_backend(
         settings=settings,
     )

+    logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
+
+    # Enable Pre-AOT Lowering for Module-Level Replacement
+    gm = pre_aot_module_replacement(gm)
+
+    logger.debug("Post-module replacement graph:\n" + str(gm.graph))
+
     # Invoke AOTAutograd to translate operators to aten
     return aot_module_simplified(
         gm,
@@ -71,6 +81,8 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
+        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+
         trt_compiled = _compile_module(
             gm,
             sample_inputs,
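As a sketch of what the new pre-AOT step does to a graph (using only names introduced in this commit; MaxPool1d is the one module registered in MODULE_SUBSTITUTION_REGISTRY), the call below should swap the 'call_module' node for a 'call_function' node targeting torch.ops.tensorrt.maxpool1d:

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_module_replacement,
)

# Symbolically trace a module containing MaxPool1d to obtain a 'call_module' node
gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.MaxPool1d(kernel_size=2)))

# Replace the MaxPool1d 'call_module' node with a 'call_function' node targeting
# torch.ops.tensorrt.maxpool1d, per the registry in the new lowering file below
gm = pre_aot_module_replacement(gm)
print(gm.graph)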
py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py (+164)
@@ -0,0 +1,164 @@
+from dataclasses import dataclass
+import traceback
+from typing import Callable, Dict, Tuple
+import torch
+from torch._custom_op import custom_op
+from torch.fx.node import Argument, Target
+import logging
+
+from torch_tensorrt.fx.converter_registry import tensorrt_converter
+from torch_tensorrt.fx.converters import acc_ops_converters
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+logger = logging.getLogger(__name__)
+
+
+@custom_op(
+    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
+    ns="tensorrt",
+)
+def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
+    # Defines operator schema, name, namespace, and function header
+    ...
+
+
+@maxpool1d.impl("cpu")
+@maxpool1d.impl("cuda")
+def maxpool1d_generic(
+    *args,
+    **kwargs,
+):
+    # Defines a converter implementation for Autograd to use for shape analysis/propagation
+    return torch.nn.functional.max_pool1d(
+        *args,
+        **kwargs,
+    )
+
+
+def maxpool1d_insertion_fn(
+    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
+) -> torch.fx.Node:
+    # Defines insertion function for new node
+    new_node = gm.graph.call_function(
+        torch.ops.tensorrt.maxpool1d,
+        args=node.args,
+        kwargs={
+            "kernel_size": submodule.kernel_size,
+            "stride": submodule.stride,
+            "padding": submodule.padding,
+            "dilation": submodule.dilation,
+            "ceil_mode": submodule.ceil_mode,
+        },
+    )
+
+    return new_node
+
+
+@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
+def aten_ops_maxpool1d(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # Defines converter replacing the default operator for this function
+    kwargs_new = {
+        "input": args[0],
+        "kernel_size": args[1],
+        "stride": args[2],
+        "padding": args[3],
+        "dilation": args[4],
+        "ceil_mode": False if len(args) < 6 else args[5],
+    }
+
+    return acc_ops_converters.acc_ops_max_pool1d(
+        network, target, None, kwargs_new, name
+    )
+
+
+@dataclass(frozen=True)
+class ModuleReplacement:
+    """Class to store key functionality for module replacement"""
+
+    # torch.ops.___ name for replacement function for module
+    new_operator: torch._ops.OpOverload
+
+    # Function taking a containing graph, a submodule, and a 'call_module' node and returning
+    # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
+    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
+    subgraph_insertion_fn: Callable[
+        [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
+    ]
+
+
+# Dictionary mapping module to ModuleReplacement instance
+MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = {
+    torch.nn.MaxPool1d: ModuleReplacement(
+        new_operator=torch.ops.tensorrt.maxpool1d,
+        subgraph_insertion_fn=maxpool1d_insertion_fn,
+    ),
+}
+
+
+def pre_aot_module_replacement(gm: torch.fx.GraphModule):
+    """Perform module-level graph replacement prior to AOT tracing
+
+    Args:
+        gm: FX GraphModule to perform module replacement on
+    Returns:
+        torch.fx.GraphModule
+
+    """
+    # Ensure all parameters are in inference mode
+    for param in gm.parameters():
+        param.requires_grad = False
+
+    # Iterate over graph nodes, extracting module calls, to check for interceptions
+    for n in gm.graph.nodes:
+        if n.op == "call_module":
+            # Extract submodule from graph
+            submodule = gm.get_submodule(n.target)
+
+            # If submodule is a member of the substitution registry, replace it
+            if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:
+
+                try:
+                    replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
+                    op, insertion_fn = (
+                        replacement.new_operator,
+                        replacement.subgraph_insertion_fn,
+                    )
+                    logger.debug(
+                        f"Replacing module of type {type(submodule)} with {op}"
+                    )
+
+                    # Insert new node prior to older node
+                    with gm.graph.inserting_before(n):
+                        new_node = insertion_fn(gm, submodule, n)
+
+                    # If submodule is not a native torch.nn module, it must be manually excluded
+                    # from Dynamo tracing
+                    if not type(submodule).__module__.startswith("torch.nn"):
+                        torch._dynamo.allowed_functions._allowed_function_ids.add(
+                            id(type(submodule))
+                        )
+
+                    # Replace all original node uses and delete node
+                    n.replace_all_uses_with(new_node)
+                    gm.graph.eliminate_dead_code()
+                    gm.recompile()
+
+                # A module replacement can fail in the event that the specific instance of the submodule cannot
+                # be replaced
+                except Exception:
+                    logger.debug(
+                        f"Encountered the following error while replacing {type(submodule)}"
+                    )
+                    logger.debug(traceback.format_exc())
+                    continue
+
+    # Perform cleanup and recompilation before returning module
+    gm.graph.eliminate_dead_code()
+    gm.recompile()
+    return gm
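Per the final bullet of the commit message, adding another module-level exclusion follows the maxpool1d template above. The schematic below is a sketch, not part of this commit: apart from ModuleReplacement, MODULE_SUBSTITUTION_REGISTRY, and the custom_op decorator, every name (the tensorrt::flatten operator, flatten_insertion_fn, the choice of torch.nn.Flatten) is hypothetical.

import torch
from torch._custom_op import custom_op
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    ModuleReplacement,
    MODULE_SUBSTITUTION_REGISTRY,
)


# 1. Hypothetical custom operator in the "tensorrt" namespace standing in for nn.Flatten
@custom_op("(Tensor x, int start_dim=0, int end_dim=-1) -> Tensor", ns="tensorrt")
def flatten(x, start_dim=0, end_dim=-1):
    # Schema, name, and namespace only, mirroring maxpool1d above
    ...


@flatten.impl("cpu")
@flatten.impl("cuda")
def flatten_generic(*args, **kwargs):
    # Generic implementation used for shape analysis/propagation
    return torch.flatten(*args, **kwargs)


# 2. Insertion function converting the original 'call_module' node into a
#    'call_function' node targeting the new operator
def flatten_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    return gm.graph.call_function(
        torch.ops.tensorrt.flatten,
        args=node.args,
        kwargs={"start_dim": submodule.start_dim, "end_dim": submodule.end_dim},
    )


# 3. A TensorRT converter for torch.ops.tensorrt.flatten.default would also be
#    registered via @tensorrt_converter, mirroring aten_ops_maxpool1d above.

# 4. Register the replacement so pre_aot_module_replacement can discover it
MODULE_SUBSTITUTION_REGISTRY[torch.nn.Flatten] = ModuleReplacement(
    new_operator=torch.ops.tensorrt.flatten,
    subgraph_insertion_fn=flatten_insertion_fn,
)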
