from dataclasses import dataclass
import logging
import traceback
from typing import Callable, Dict, Tuple, Type

import torch
from torch._custom_op import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

logger = logging.getLogger(__name__)


@custom_op(
    "(Tensor x, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> Tensor",
    ns="tensorrt",
)
def maxpool1d(x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
    # Defines the operator schema, name, namespace, and function header
    ...


@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines a generic eager implementation which Autograd can use
    # for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )


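# A quick sanity check one might run (hypothetical, not part of the
# registration machinery): with all schema arguments supplied explicitly,
# the custom op should agree with the eager max_pool1d it wraps.
#
#   x = torch.randn(1, 3, 8)
#   out = torch.ops.tensorrt.maxpool1d(x, [2], [2], [0], [1], False)
#   assert torch.equal(out, torch.nn.functional.max_pool1d(x, 2, 2, 0, 1, False))

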
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node


@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def aten_ops_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines the TensorRT converter for the custom operator, dispatching
    # to the existing acc_ops max_pool1d converter
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
        "stride": args[2],
        "padding": args[3],
        "dilation": args[4],
        "ceil_mode": False if len(args) < 6 else args[5],
    }

    return acc_ops_converters.acc_ops_max_pool1d(
        network, target, None, kwargs_new, name
    )


@dataclass(frozen=True)
class ModuleReplacement:
    """Class to store key functionality for module replacement"""

    # torch.ops.___ name for replacement function for module
    new_operator: torch._ops.OpOverload

    # Function taking a containing graph, a submodule, and a 'call_module' node and returning
    # a replacement node, with type 'call_function', or raising an error if incompatibility is detected
    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
    subgraph_insertion_fn: Callable[
        [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
    ]


# Dictionary mapping module types to their ModuleReplacement instances
MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = {
    torch.nn.MaxPool1d: ModuleReplacement(
        new_operator=torch.ops.tensorrt.maxpool1d,
        subgraph_insertion_fn=maxpool1d_insertion_fn,
    ),
}


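# Extending the registry follows the same pattern as maxpool1d above: define
# a custom op, an insertion function, and a TensorRT converter, then add an
# entry here. A hypothetical sketch (torch.ops.tensorrt.maxpool2d and
# maxpool2d_insertion_fn are assumed names, not defined in this change):
#
#   MODULE_SUBSTITUTION_REGISTRY[torch.nn.MaxPool2d] = ModuleReplacement(
#       new_operator=torch.ops.tensorrt.maxpool2d,
#       subgraph_insertion_fn=maxpool2d_insertion_fn,
#   )

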
def pre_aot_module_replacement(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Perform module-level graph replacement prior to AOT tracing

    Args:
        gm: FX GraphModule to perform module replacement on
    Returns:
        torch.fx.GraphModule with registered modules replaced by custom operators
    """
    # Ensure all parameters are in inference mode
    for param in gm.parameters():
        param.requires_grad = False

    # Iterate over graph nodes, extracting module calls, to check for interceptions
    for n in gm.graph.nodes:
        if n.op == "call_module":
            # Extract submodule from graph
            submodule = gm.get_submodule(n.target)

            # If submodule is a member of the substitution registry, replace it
            if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:
                try:
                    replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
                    op, insertion_fn = (
                        replacement.new_operator,
                        replacement.subgraph_insertion_fn,
                    )
                    logger.debug(
                        f"Replacing module of type {type(submodule)} with {op}"
                    )

                    # Insert new node prior to older node
                    with gm.graph.inserting_before(n):
                        new_node = insertion_fn(gm, submodule, n)

                    # If submodule is not a native torch.nn module, it must be manually excluded
                    # from Dynamo tracing
                    if not type(submodule).__module__.startswith("torch.nn"):
                        torch._dynamo.allowed_functions._allowed_function_ids.add(
                            id(type(submodule))
                        )

                    # Replace all uses of the original node, then remove it via
                    # dead code elimination and recompile the module
                    n.replace_all_uses_with(new_node)
                    gm.graph.eliminate_dead_code()
                    gm.recompile()

                # A module replacement can fail if the specific instance of the
                # submodule cannot be replaced
                except Exception:
                    logger.debug(
                        f"Encountered the following error while replacing {type(submodule)}:"
                    )
                    logger.debug(traceback.format_exc())
                    continue

    # Perform cleanup and recompilation before returning module
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
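

if __name__ == "__main__":
    # Minimal usage sketch (assumes a working torch_tensorrt install; not
    # part of the replacement machinery): trace a module containing
    # nn.MaxPool1d and confirm its call_module node is rewritten to the
    # tensorrt.maxpool1d custom operator.
    class Pool(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool1d(kernel_size=2)

        def forward(self, x):
            return self.pool(x)

    gm = torch.fx.symbolic_trace(Pool())
    gm = pre_aot_module_replacement(gm)
    assert any(
        n.op == "call_function" and n.target == torch.ops.tensorrt.maxpool1d
        for n in gm.graph.nodes
    ), "MaxPool1d module was not replaced"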