@@ -1,82 +1,12 @@
 from dataclasses import dataclass
-import traceback
-from typing import Callable, Dict, Tuple
+from typing import Any, Callable, Dict
 import torch
-from torch._custom_op import custom_op
-from torch.fx.node import Argument, Target
 import logging
 
-from torch_tensorrt.fx.converter_registry import tensorrt_converter
-from torch_tensorrt.fx.converters import acc_ops_converters
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 logger = logging.getLogger(__name__)
 
 
-@custom_op(
-    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
-    ns="tensorrt",
-)
-def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
-    # Defines operator schema, name, namespace, and function header
-    ...
-
-
-@maxpool1d.impl("cpu")
-@maxpool1d.impl("cuda")
-def maxpool1d_generic(
-    *args,
-    **kwargs,
-):
-    # Defines a converter implementation for Autograd to use for shape analysis/propagation
-    return torch.nn.functional.max_pool1d(
-        *args,
-        **kwargs,
-    )
-
-
-def maxpool1d_insertion_fn(
-    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
-) -> torch.fx.Node:
-    # Defines insertion function for new node
-    new_node = gm.graph.call_function(
-        torch.ops.tensorrt.maxpool1d,
-        args=node.args,
-        kwargs={
-            "kernel_size": submodule.kernel_size,
-            "stride": submodule.stride,
-            "padding": submodule.padding,
-            "dilation": submodule.dilation,
-            "ceil_mode": submodule.ceil_mode,
-        },
-    )
-
-    return new_node
-
-
-@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
-def aten_ops_maxpool1d(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> TRTTensor:
-    # Defines converter replacing the default operator for this function
-    kwargs_new = {
-        "input": args[0],
-        "kernel_size": args[1],
-        "stride": args[2],
-        "padding": args[3],
-        "dilation": args[4],
-        "ceil_mode": False if len(args) < 6 else args[5],
-    }
-
-    return acc_ops_converters.acc_ops_max_pool1d(
-        network, target, None, kwargs_new, name
-    )
-
-
 @dataclass(frozen=True)
 class ModuleReplacement:
     """Class to store key functionality for module replacement"""
@@ -93,12 +23,37 @@ class ModuleReplacement:
 
 
 # Dictionary mapping module to ModuleReplacement instance
-MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = {
-    torch.nn.MaxPool1d: ModuleReplacement(
-        new_operator=torch.ops.tensorrt.maxpool1d,
-        subgraph_insertion_fn=maxpool1d_insertion_fn,
-    ),
-}
+MODULE_SUBSTITUTION_REGISTRY: Dict[torch.nn.Module, ModuleReplacement] = dict()
+
+
+def module_substitution(
+    module_to_replace: torch.nn.Module,
+    new_operator: torch._ops.OpOverload,
+    enabled: bool = True,
+) -> Callable[[Any], Any]:
+    """Decorator to register subgraph insertion functions
+
+    Args:
+        module_to_replace: nn.Module to replace
+        new_operator: Custom torch operator to replace with
+        enabled: Whether the substitution is enabled or disabled
+    Returns:
+        A decorator which registers the provided subgraph insertion function
+    """
+
+    def register_substitution(subgraph_insertion_fn):
+        """Function for use if substitution is enabled"""
+        module_replacement = ModuleReplacement(
+            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
+        )
+        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
+        return subgraph_insertion_fn
+
+    def disable_substitution(subgraph_insertion_fn):
+        """Function for use if substitution is disabled"""
+        return subgraph_insertion_fn
+
+    return register_substitution if enabled else disable_substitution
 
 
 def pre_aot_module_replacement(gm: torch.fx.GraphModule):
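
For reference, the MaxPool1d entry deleted from the hardcoded registry above would now be registered through this decorator instead. A minimal sketch, assuming the insertion function keeps the same signature and body as the removed maxpool1d_insertion_fn (where it is re-declared is not shown in this diff):

@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Swap the module call for the custom tensorrt op, forwarding the
    # module's hyperparameters as kwargs (same logic as the removed version)
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )
    return new_node

Decorating with enabled=False would leave the function defined but skip the MODULE_SUBSTITUTION_REGISTRY entry, effectively switching the substitution off.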
@@ -144,7 +99,7 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
                 id(type(submodule))
             )
 
-            # Replace all original node uses and delete node
+            # Replace all original node uses and clean up graph
             n.replace_all_uses_with(new_node)
             gm.graph.eliminate_dead_code()
             gm.recompile()
@@ -153,9 +108,9 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
         # be replaced
         except Exception:
             logger.debug(
-                f"Encountered the following error while replacing {type(submodule)}"
+                f"Encountered error while replacing {type(submodule)}",
+                exc_info=True,
             )
-            logger.debug(traceback.format_exc())
             continue
 
     # Perform cleanup and recompilation before returning module
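
The exc_info=True argument makes the logging machinery format and append the active exception's traceback to the record itself, which is what allows the separate traceback.format_exc() call and the traceback import to be dropped. A minimal standalone illustration (logger setup and message here are for demonstration only):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    raise ValueError("demo failure")
except Exception:
    # With exc_info=True, the formatted traceback is attached to this
    # record, replacing the explicit logger.debug(traceback.format_exc())
    logger.debug("Encountered error while replacing <module>", exc_info=True)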