docsrc/user_guide/saving_models.rst (+31 −18)

Dynamo IR
-------------

The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.export.ExportedProgram` object by default.
In addition, we provide a new parameter `output_format` in the `CompilationSettings` object provided before compilation.
The `output_format` can take the following options:

* `exported_program` (or) `ep` : This is the default. It returns an ExportedProgram.
* `torchscript` (or) `ts` : This returns a TorchScript module.
* `graph_module` (or) `fx` : This returns a `torch.fx.GraphModule`, which can be traced into TorchScript to save to disk.

a) Torchscript
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you set `output_format="torchscript"`, compilation will return a `ScriptModule`, which can be serialized via `torch.jit.save`.

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="torchscript")
    torch.jit.save(trt_ts, "trt_model.ts")
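The saved artifact is an ordinary TorchScript module and comes back through the standard `torch.jit.load` API. Below is a CPU-only sketch of that round trip, using a traced stand-in module since the real compiled result requires TensorRT and a GPU:

```python
import os
import tempfile

import torch

class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.sigmoid(x)

# Stand-in for the output_format="torchscript" result: any ScriptModule goes
# through the same torch.jit.save / torch.jit.load calls.
scripted = torch.jit.trace(MyModel().eval(), torch.randn(1, 3, 224, 224))

path = os.path.join(tempfile.mkdtemp(), "trt_model.ts")
torch.jit.save(scripted, path)
reloaded = torch.jit.load(path)
```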

c) GraphModule
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
Internally, the partitioning, lowering, and conversion phases operate on GraphModule objects. These can be either traced into TorchScript modules or exported as an `ExportedProgram`.
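As a CPU-only sketch of that workflow, with `torch.fx.symbolic_trace` standing in for the GraphModule that `output_format="graph_module"` would return:

```python
import os
import tempfile

import torch

class MyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

# Stand-in for the output_format="graph_module" result
gm = torch.fx.symbolic_trace(MyModel().eval())

# GraphModules are not directly serializable; trace into a ScriptModule first
example = torch.randn(1, 3, 8, 8)
traced = torch.jit.trace(gm, example)

path = os.path.join(tempfile.mkdtemp(), "model.ts")
torch.jit.save(traced, path)
reloaded = torch.jit.load(path)
```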
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -161,6 +163,7 @@ def compile(
161
163
        enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
        dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
        hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
        **kwargs: Any,
    Returns:
        torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT

py/torch_tensorrt/dynamo/_exporter.py (+88 −50)

import operator
from typing import Any, Dict, Sequence, Tuple, cast

def export(
    gm: torch.fx.GraphModule,
    inputs: Sequence[torch.Tensor],
    output_format: str,
) -> ExportedProgram:
    """Export the result of TensorRT compilation into the desired output format.

    Arguments:
        gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
        inputs (torch.Tensor): Torch input tensors
        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
    """