
Commit f092bbf

chore: Set default return type to ExportedProgram

chore: Add output_format flag
chore: updates
chore: additional fixes
chore: add break

1 parent 4b608f0 · commit f092bbf

File tree: 8 files changed, +211 -153 lines changed


Diff for: .github/workflows/build-test.yml (+1)

@@ -142,6 +142,7 @@ jobs:
           ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/output_format.xml --ir dynamo models/test_output_format.py
           popd

   tests-py-dynamo-serde:
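Note: the workflow now runs models/test_output_format.py, which is not included in this diff. Below is a minimal sketch of what such a test might contain; the file contents, model, and test names are assumptions, not the actual file, and the real suite drives `ir` via the `--ir dynamo` pytest flag rather than passing it explicitly.

    # Hypothetical sketch of models/test_output_format.py (the real file is not shown in this commit)
    import pytest
    import torch
    import torch_tensorrt


    class SimpleModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)


    @pytest.mark.parametrize(
        "output_format, expected_type",
        [
            ("exported_program", torch.export.ExportedProgram),
            ("ep", torch.export.ExportedProgram),
            ("torchscript", torch.jit.ScriptModule),
            ("ts", torch.jit.ScriptModule),
            ("graph_module", torch.fx.GraphModule),
            ("fx", torch.fx.GraphModule),
        ],
    )
    def test_output_format(output_format, expected_type):
        model = SimpleModel().eval().cuda()
        inputs = [torch.randn((1, 3, 224, 224)).cuda()]
        # Each output_format value (and its short alias) should yield the matching type
        result = torch_tensorrt.compile(
            model, ir="dynamo", inputs=inputs, output_format=output_format
        )
        assert isinstance(result, expected_type)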

Diff for: docsrc/user_guide/saving_models.rst (+31 -18)

@@ -14,14 +14,18 @@ Saving models compiled with Torch-TensorRT varies slightly with the `ir` that ha
 Dynamo IR
 -------------

-Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
-The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.export.ExportedProgram` object by default.
+In addition, we provide a new parameter `output_format` in the `CompilationSettings` object provided before compilation.
+The `output_format` parameter can take the following options:

-a) Converting to Torchscript
+* `exported_program` (or) `ep` : This is the default. It returns an ExportedProgram.
+* `torchscript` (or) `ts` : This returns a TorchScript module.
+* `graph_module` (or) `fx` : This returns a `torch.fx.GraphModule`, which can be traced into TorchScript to save to disk.
+
+a) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
-The following code illustrates this approach.
+If you set `output_format="torchscript"`, compilation will return a `ScriptModule`, which can be serialized via `torch.jit.save`.

 .. code-block:: python

@@ -30,9 +34,9 @@ The following code illustrates this approach.

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
-    trt_traced_model = torch.jit.trace(trt_gm, inputs)
-    torch.jit.save(trt_traced_model, "trt_model.ts")
+    # trt_ts is a torch.jit.ScriptModule object
+    trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="torchscript")
+    torch.jit.save(trt_ts, "trt_model.ts")

     # Later, you can load it and run inference
     model = torch.jit.load("trt_model.ts").cuda()

@@ -41,8 +45,7 @@ The following code illustrates this approach.
 b) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
-`torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk.
+`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X, is the default return type of Torch-TensorRT compilation.

 .. code-block:: python

@@ -51,26 +54,36 @@ b) ExportedProgram

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) # Output is a torch.fx.GraphModule
-    # Transform and create an exported program
-    trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
-    torch.export.save(trt_exp_program, "trt_model.ep")
+    # trt_ep is a torch.export.ExportedProgram object
+    trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+    torch.export.save(trt_ep, "trt_model.ep")

     # Later, you can load it and run inference
     model = torch.export.load("trt_model.ep")
     model(*inputs)

-`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule to their corresponding nodes and stiches all the nodes together.
-This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
+c) GraphModule
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
+We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
+Internally, the partitioning, lowering, and conversion phases operate on GraphModule objects. These can be either traced into TorchScript modules or
+exported into `ExportedProgram` objects.

+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")

 Torchscript IR
 -------------

 In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
-This behavior stays the same in 2.X versions as well.
+For `ir=ts`, this behavior stays the same in 2.X versions as well.

 .. code-block:: python
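Note: the new GraphModule example above stops at compilation without showing serialization. A short sketch of how the returned GraphModule could be persisted, following the torch.jit.trace approach the previous revision of this page documented (illustrative only, not part of this commit; MyModel is the placeholder module from the doc snippets):

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # MyModel as in the doc snippets above (assumed)
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")

    # A GraphModule is not directly serializable; trace it into TorchScript first,
    # mirroring the approach the earlier version of this page used.
    trt_ts = torch.jit.trace(trt_gm, inputs)
    torch.jit.save(trt_ts, "trt_model.ts")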

Diff for: py/torch_tensorrt/dynamo/_compiler.py (+9 -3)

@@ -5,6 +5,7 @@
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

 import torch
+import torch_tensorrt
 from torch.export import ExportedProgram
 from torch.fx.node import Target
 from torch_tensorrt import _enums

@@ -29,6 +30,7 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
+    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,

@@ -46,6 +48,7 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
+from torch_tensorrt.dynamo._exporter import export
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,

@@ -66,8 +69,6 @@
     to_torch_tensorrt_device,
 )

-import torch_tensorrt
-
 logger = logging.getLogger(__name__)


@@ -103,6 +104,7 @@ def compile(
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     dryrun: bool = DRYRUN,
     hardware_compatible: bool = HARDWARE_COMPATIBLE,
+    output_format: str = OUTPUT_FORMAT,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT

@@ -161,6 +163,7 @@ def compile(
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
         dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
+        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT

@@ -238,11 +241,14 @@ def compile(
         "dla_global_dram_size": dla_global_dram_size,
         "dryrun": dryrun,
         "hardware_compatible": hardware_compatible,
+        "output_format": output_format,
     }

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-    return compile_module(gm, inputs, settings)
+    trt_gm = compile_module(gm, inputs, settings)
+    trt_result = export(trt_gm, torch_inputs, output_format)
+    return trt_result


 def compile_module(
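Note: since compile() now ends by dispatching through the exporter, a GraphModule result can also be converted to the other formats after the fact instead of recompiling. A hedged sketch using the export helper this commit wires in (a private API that may change; MyModel is an assumed user module):

    import torch
    import torch_tensorrt
    from torch_tensorrt.dynamo._exporter import export  # private helper added in this diff

    model = MyModel().eval().cuda()  # assumed user module
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Compile once to the raw GraphModule (the "fx" alias) ...
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="fx")

    # ... then convert it with the same dispatch compile() now uses internally.
    trt_ep = export(trt_gm, inputs, "exported_program")  # torch.export.ExportedProgram
    trt_ts = export(trt_gm, inputs, "torchscript")       # ScriptModule via torch.jit.trace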

Diff for: py/torch_tensorrt/dynamo/_defaults.py (+1)

@@ -26,6 +26,7 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
+OUTPUT_FORMAT = "exported_program"


 def default_device() -> Device:

Diff for: py/torch_tensorrt/dynamo/_exporter.py (+88 -50)

@@ -1,4 +1,3 @@
-import copy
 import operator
 from typing import Any, Dict, Sequence, Tuple, cast

@@ -19,50 +18,43 @@
 def export(
     gm: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-    *,
-    ir: str = "torchscript",
+    output_format: str,
 ) -> ExportedProgram:
-    """Export a program (``torch.fx.GraphModule``) for serialization with the TensorRT engines embedded.
-
-    > Note: When ExportedProgram becomes stable, this function will get merged into ``torch_tensorrt.dynamo.compile``
+    """Export the result of TensorRT compilation into the desired output format.

     Arguments:
-        src_gm (torch.fx.GraphModule): Source module, generated by torch.export (The module provided to ``torch_tensorrt.dynamo.compile``)
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
-
-    Keyword Arguments:
-        inputs (Any): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
-            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
-            to select device type. ::
-
-                input=[
-                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
-                    torch_tensorrt.Input(
-                        min_shape=(1, 224, 224, 3),
-                        opt_shape=(1, 512, 512, 3),
-                        max_shape=(1, 1024, 1024, 3),
-                        dtype=torch.int32
-                        format=torch.channel_last
-                    ), # Dynamic input shape for input #2
-                    torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
-        ir (str): torchscript | exported_program. Based on the provided ir, the output type would be a torchscript or exported program.
+        inputs (torch.Tensor): Torch input tensors
+        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
     """
-    if ir == "torchscript":
+    if output_format == "torchscript" or output_format == "ts":
         return torch.jit.trace(gm, inputs)
-    elif ir == "exported_program":
+    elif output_format == "exported_program" or output_format == "ep":
         patched_module = transform(gm, inputs)
         exp_program = create_trt_exp_program(patched_module)
-
         return exp_program
+    elif output_format == "graph_module" or output_format == "fx":
+        return gm
     else:
         raise ValueError(
-            f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
+            f"Invalid output format {output_format} specified. Supported options include exported_program (or) ep | torchscript (or) ts | graph_module (or) fx"
         )


 def transform(
     gm: torch.fx.GraphModule, inputs: Sequence[torch.Tensor]
 ) -> torch.fx.GraphModule:
+    """
+    Transforms the graphmodule by inlining Pytorch and TensorRT submodules.
+    Inlining collapses submodules into nodes which is necessary for torch.export
+    serialization.
+
+    Arguments:
+        gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
+        inputs (torch.Tensor): Torch input tensors
+
+    Returns an inlined torch.fx.GraphModule
+    """
     # Run shape analysis
     _, outputs_map = partitioning.run_shape_analysis(gm, inputs)

@@ -72,10 +64,6 @@ def transform(
     # Inline pytorch submodules
     inline_torch_modules(gm)

-    # Lift constant buffers and parameters in the graph
-    # torch.export serialization expects them to be lifted
-    lift_constant_pass(gm)
-
     # Clean the graph
     gm.delete_all_unused_submodules()
     gm.graph.eliminate_dead_code()

@@ -84,34 +72,80 @@
     return gm


-def lift_constant_pass(trt_gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule:
+    """
+    Given an unlifted fx.GraphModule, lift all parameters, buffers into placeholders.
+    Arguments:
+        gm (torch.fx.GraphModule): Unlifted GraphModule which contains parameters and buffers as get_attr nodes.
+        graph_signature (torch.export.ExportGraphSignature): Instance of ExportGraphSignature class created for the output ExportedProgram.
+            After lifting, this graph_signature will be modified with the parameters and buffers added appropriately.
+    Returns:
+        A lifted fx.GraphModule, modified graph_signature and a new state_dict
+    """
+    # Get the state_dict of graph_module. This is different from exported_program.state_dict
+    # exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict
+    # has all parameters registered as torch.tensors.
+    state_dict = gm.state_dict()
+
     fake_mode = detect_fake_mode(
-        tuple(
-            node.meta["val"] for node in trt_gm.graph.nodes if node.op == "placeholder"
-        )
+        tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
     )
+    assert fake_mode is not None

+    # Locate the user input to insert new placeholders before them
     first_user_input = None
-    for node in trt_gm.graph.nodes:
-        if node.op == "placeholder":
+    for node in gm.graph.nodes:
+        if node.op == "placeholder" and node.name in graph_signature.user_inputs:
             first_user_input = node
             break

-    for node in trt_gm.graph.nodes:
+    # At first the user_inputs are only present in the graph_signature.input_specs and hence non_user_input_idx=0
+    # The input_specs should be of the form [params, buffers, constant_tensors, user_inputs]
+    non_user_input_idx = 0
+    for node in gm.graph.nodes:
         if node.op == "get_attr":
-            constant_tensor = getattr(trt_gm, node.target)
-            with trt_gm.graph.inserting_before(first_user_input):
-                const_placeholder_node = trt_gm.graph.placeholder(node.target)
-                const_placeholder_node.meta = copy.deepcopy(node.meta)
+            constant_tensor = getattr(gm, node.target)
+            input_kind = InputKind.CONSTANT_TENSOR
+
+            # state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
+            for name, _ in gm.named_parameters():
+                if node.target == name:
+                    input_kind = InputKind.PARAMETER
+                    state_dict[name] = constant_tensor
+                    break
+            for name, _ in gm.named_buffers():
+                if node.target == name:
+                    input_kind = InputKind.BUFFER
+                    state_dict[name] = constant_tensor
+                    break
+
+            # Replace get_attr nodes with placeholder nodes and copy metadata.
+            with gm.graph.inserting_before(first_user_input):
+                const_placeholder_node = gm.graph.placeholder(node.target)
+                for k, v in node.meta.items():
+                    const_placeholder_node.meta[k] = v
                 const_placeholder_node.meta["val"] = fake_mode.from_tensor(
                     constant_tensor
                 )
                 node.replace_all_uses_with(const_placeholder_node)
-                trt_gm.graph.erase_node(node)
+                gm.graph.erase_node(node)
+
+                # Add these parameters/buffers/constants to the existing graph signature
+                # before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
+                graph_signature.input_specs.insert(
+                    non_user_input_idx,
+                    InputSpec(
+                        kind=input_kind,
+                        arg=TensorArgument(name=const_placeholder_node.name),
+                        target=node.target,
+                    ),
+                )
+                non_user_input_idx += 1
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()

-    trt_gm.graph.eliminate_dead_code()
-    trt_gm.graph.lint()
-    return trt_gm
+    return gm, graph_signature, state_dict


 def get_duplicate_nodes(

@@ -140,7 +174,7 @@ def get_duplicate_nodes(
 def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """
     Inline a submodule within the parent graph (gm). All `call_module` nodes
-    should be replaced by their submodule nodes.
+    should be replaced by their nodes in the submodule.
     """
     # Clean the graph
     gm.graph.eliminate_dead_code()

@@ -165,7 +199,6 @@ def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
         # Copy all nodes in the submodule into gm and
         # store the output node of this submodule which is now present in gm
-
         submodule_output = gm.graph.graph_copy(submodule.graph, val_map)

         # Get their references (since we copied) in the parent graph (gm)

@@ -227,6 +260,7 @@
     """Creates a new Exported Program. This function takes a torch.fx.GraphModule which has TRT engines
     and constructs an Exported Program object with the new IO node names and state_dict
     """
+
     input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
     output_nodes = [node for node in gm.graph.nodes if node.op == "output"]
     assert output_nodes

@@ -245,8 +279,12 @@
         input_specs=input_specs, output_specs=output_specs
     )

+    # Lift parameters/buffers/constants in the graph
+    # torch.export serialization expects them to be lifted
+    gm, trt_graph_signature, state_dict = lift(gm, trt_graph_signature)
+
     trt_exp_program = ExportedProgram(
-        gm, gm.graph, trt_graph_signature, gm.state_dict(), {}, [], [], []
+        gm, gm.graph, trt_graph_signature, state_dict, {}, [], [], []
     )

     return trt_exp_program
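Note: to make the new lift() pass concrete, here is a minimal, self-contained FX illustration of the same idea: a get_attr node for a buffer is replaced with a placeholder so the tensor becomes an explicit graph input, which is what torch.export serialization expects. This is a simplified sketch, not the library code (no graph_signature or state_dict bookkeeping), using a toy module of our own.

    import torch
    import torch.fx


    class Scale(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.register_buffer("scale", torch.tensor(2.0))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.scale


    gm = torch.fx.symbolic_trace(Scale())
    # Before lifting: the graph reads the "scale" buffer through a get_attr node.

    first_placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
    for node in list(gm.graph.nodes):
        if node.op == "get_attr":
            # Insert the new placeholder before the user inputs, mirroring the
            # [params, buffers, constants, user_inputs] ordering lift() maintains.
            with gm.graph.inserting_before(first_placeholder):
                const_placeholder = gm.graph.placeholder(node.target)
            node.replace_all_uses_with(const_placeholder)
            gm.graph.erase_node(node)

    gm.graph.lint()
    gm.recompile()

    # After lifting, the buffer is passed explicitly as the first argument.
    out = gm(torch.tensor(2.0), torch.randn(3))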
