Commit c3c266d

fix: Refactor code and add testing
1 parent 70530a3 commit c3c266d

7 files changed: +187 -86 lines changed
Diff for: py/torch_tensorrt/dynamo/backend/lowering/__init__.py

+6 -4
@@ -1,7 +1,9 @@
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from ._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.backend.lowering._partition import (
-    partition,
-    get_submod_inputs,
+from ._pre_aot_lowering import (
+    MODULE_SUBSTITUTION_REGISTRY,
+    module_substitution,
 )
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
+from .module_substitutions import *
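With this change the lowering package re-exports the substitution machinery alongside the existing helpers, and importing it also runs the wildcard import of .module_substitutions, which registers the built-in substitutions (currently MaxPool1d) into MODULE_SUBSTITUTION_REGISTRY. A minimal import sketch, using only names re-exported above:

from torch_tensorrt.dynamo.backend.lowering import (
    get_decompositions,
    partition,
    get_submod_inputs,
    DEFAULT_SINGLE_NODE_PARTITIONS,
    MODULE_SUBSTITUTION_REGISTRY,
    module_substitution,
)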

Diff for: py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+10 -2
@@ -1,9 +1,10 @@
 import logging
-from typing import Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence, Set
 
 import torch
 
 from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
@@ -14,6 +15,11 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+    "torch.ops." + str(module.new_operator)
+    for module in MODULE_SUBSTITUTION_REGISTRY.values()
+)
+
 
 class TRTPartitioner(CapabilityBasedPartitioner):
     """Partitioner to split an FX graph into subgraphs based on operator support
@@ -35,7 +41,9 @@ def __init__(
         operator_support: OperatorSupport,
         *,
         non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size=MIN_BLOCK_SIZE,
     ) -> None:
         super().__init__(
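For reference, a rough sketch of what the new default resolves to once the lowering package (and with it the MaxPool1d substitution from this commit) has been imported; the printed value is an assumption based on the single registered substitution:

from torch_tensorrt.dynamo.backend.lowering import DEFAULT_SINGLE_NODE_PARTITIONS

# One entry per registered ModuleReplacement: "torch.ops." + str(new_operator)
print(DEFAULT_SINGLE_NODE_PARTITIONS)  # e.g. {"torch.ops.tensorrt.maxpool1d"}

Using this set as the default for allowed_single_node_partition_ops lets a lone substituted operator form its own TRT partition.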

Diff for: py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py

+35 -80
@@ -1,82 +1,12 @@
 from dataclasses import dataclass
-import traceback
-from typing import Callable, Dict, Tuple
+from typing import Any, Callable, Dict
 import torch
-from torch._custom_op import custom_op
-from torch.fx.node import Argument, Target
 import logging
 
-from torch_tensorrt.fx.converter_registry import tensorrt_converter
-from torch_tensorrt.fx.converters import acc_ops_converters
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 logger = logging.getLogger(__name__)
 
 
-@custom_op(
-    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
-    ns="tensorrt",
-)
-def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
-    # Defines operator schema, name, namespace, and function header
-    ...
-
-
-@maxpool1d.impl("cpu")
-@maxpool1d.impl("cuda")
-def maxpool1d_generic(
-    *args,
-    **kwargs,
-):
-    # Defines a converter implementation for Autograd to use for shape analysis/propagation
-    return torch.nn.functional.max_pool1d(
-        *args,
-        **kwargs,
-    )
-
-
-def maxpool1d_insertion_fn(
-    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
-) -> torch.fx.Node:
-    # Defines insertion function for new node
-    new_node = gm.graph.call_function(
-        torch.ops.tensorrt.maxpool1d,
-        args=node.args,
-        kwargs={
-            "kernel_size": submodule.kernel_size,
-            "stride": submodule.stride,
-            "padding": submodule.padding,
-            "dilation": submodule.dilation,
-            "ceil_mode": submodule.ceil_mode,
-        },
-    )
-
-    return new_node
-
-
-@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
-def aten_ops_maxpool1d(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> TRTTensor:
-    # Defines converter replacing the default operator for this function
-    kwargs_new = {
-        "input": args[0],
-        "kernel_size": args[1],
-        "stride": args[2],
-        "padding": args[3],
-        "dilation": args[4],
-        "ceil_mode": False if len(args) < 6 else args[5],
-    }
-
-    return acc_ops_converters.acc_ops_max_pool1d(
-        network, target, None, kwargs_new, name
-    )
-
-
 @dataclass(frozen=True)
 class ModuleReplacement:
     """Class to store key functionality for module replacement"""
@@ -93,12 +23,37 @@ class ModuleReplacement:
 
 
 # Dictionary mapping module to ModuleReplacement instance
-MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = {
-    torch.nn.MaxPool1d: ModuleReplacement(
-        new_operator=torch.ops.tensorrt.maxpool1d,
-        subgraph_insertion_fn=maxpool1d_insertion_fn,
-    ),
-}
+MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()
+
+
+def module_substitution(
+    module_to_replace: torch.nn.Module,
+    new_operator: torch._ops.OpOverload,
+    enabled: bool = True,
+) -> Callable[[Any], Any]:
+    """Decorator to register subgraph insertion functions
+
+    Args:
+        module_to_replace: nn.Module to replace
+        new_operator: Custom torch operator to replace with
+        enabled: Whether the substitution is enabled or disabled
+    Returns:
+        torch.fx.GraphModule
+    """
+
+    def register_substitution(subgraph_insertion_fn):
+        """Function for use if substitution is enabled"""
+        module_replacement = ModuleReplacement(
+            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
+        )
+        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
+        return subgraph_insertion_fn
+
+    def disable_substitution(subgraph_insertion_fn):
+        """Function for use if substitution is disabled"""
+        return subgraph_insertion_fn
+
+    return register_substitution if enabled else disable_substitution
 
 
 def pre_aot_module_replacement(gm: torch.fx.GraphModule):
@@ -144,7 +99,7 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
                            id(type(submodule))
                        )
 
-                    # Replace all original node uses and delete node
+                    # Replace all original node uses and clean up graph
                     n.replace_all_uses_with(new_node)
                     gm.graph.eliminate_dead_code()
                     gm.recompile()
@@ -153,9 +108,9 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
                 # be replaced
                 except Exception:
                     logger.debug(
-                        f"Encountered the following error while replacing {type(submodule)}"
+                        f"Encountered error while replacing {type(submodule)}",
+                        exc_info=True,
                     )
-                    logger.debug(traceback.format_exc())
                     continue
 
     # Perform cleanup and recompilation before returning module
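End to end, a decorated insertion function populates MODULE_SUBSTITUTION_REGISTRY, and pre_aot_module_replacement rewrites matching call_module nodes before AOT tracing. A rough usage sketch, assuming a build of torch_tensorrt with these changes applied (importing the module below also triggers registration of the MaxPool1d substitution defined later in this commit):

import torch
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_module_replacement,
)

# Symbolic tracing keeps nn.MaxPool1d as a call_module node; the replacement pass
# then rewrites it into a call_function node targeting torch.ops.tensorrt.maxpool1d.
gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.MaxPool1d(2)))
gm = pre_aot_module_replacement(gm)
print(gm.graph)  # expect a call_function targeting torch.ops.tensorrt.maxpool1d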
Diff for: py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py
@@ -0,0 +1 @@
+from .maxpool1d import *
Diff for: py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py
@@ -0,0 +1,75 @@
+from typing import Dict, Tuple
+import torch
+from torch._custom_op import custom_op
+from torch.fx.node import Argument, Target
+
+from torch_tensorrt.fx.converter_registry import tensorrt_converter
+from torch_tensorrt.fx.converters import acc_ops_converters
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from torch_tensorrt.dynamo.backend.lowering import module_substitution
+
+
+@custom_op(
+    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
+    ns="tensorrt",
+)
+def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
+    # Defines operator schema, name, namespace, and function header
+    ...
+
+
+@maxpool1d.impl("cpu")
+@maxpool1d.impl("cuda")
+def maxpool1d_generic(
+    *args,
+    **kwargs,
+):
+    # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
+    return torch.nn.functional.max_pool1d(
+        *args,
+        **kwargs,
+    )
+
+
+@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
+def aten_ops_maxpool1d(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # Defines converter replacing the default operator for this function
+    kwargs_new = {
+        "input": args[0],
+        "kernel_size": args[1],
+        "stride": args[2],
+        "padding": args[3],
+        "dilation": args[4],
+        "ceil_mode": False if len(args) < 6 else args[5],
+    }
+
+    return acc_ops_converters.acc_ops_max_pool1d(
+        network, target, None, kwargs_new, name
+    )
+
+
+@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
+def maxpool1d_insertion_fn(
+    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
+) -> torch.fx.Node:
+    # Defines insertion function for new node
+    new_node = gm.graph.call_function(
+        torch.ops.tensorrt.maxpool1d,
+        args=node.args,
+        kwargs={
+            "kernel_size": submodule.kernel_size,
+            "stride": submodule.stride,
+            "padding": submodule.padding,
+            "dilation": submodule.dilation,
+            "ceil_mode": submodule.ceil_mode,
+        },
+    )
+
+    return new_node
@@ -0,0 +1,55 @@
+import torch
+from utils import lower_graph_testing
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch_tensorrt.dynamo import compile
+
+
+class TestMaxPool1D(TestCase):
+    def test_pre_aot_lowering_maxpool1d(self):
+        class MaxPool1D(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+                self.maxpool = torch.nn.MaxPool1d(2)
+
+            def forward(self, x):
+                return self.maxpool(x)
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {torch.ops.tensorrt.maxpool1d.default}
+
+        inputs = [
+            torch.rand(
+                9,
+                16,
+                2,
+            ),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(MaxPool1D())
+        _, expected_ops_unseen = lower_graph_testing(
+            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
+        self.assertAlmostEqual(
+            max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model."
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

Diff for: py/torch_tensorrt/dynamo/backend/test/utils.py

+5
@@ -8,6 +8,9 @@
 from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
 )
+from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
+    pre_aot_module_replacement,
+)
 
 from torch._dynamo.backends.common import fake_tensor_unsupported
 
@@ -31,6 +34,8 @@ def fx_dynamo_testing_backend(
         torch_executed_ops=torch_executed_ops,
     )
 
+    gm = pre_aot_module_replacement(gm)
+
     # Invoke AOTAutograd to translate operators to aten
     return aot_module_simplified(
         gm,
