
chore: Set return type of compilation to ExportedProgram [release/2.2] #2607


Merged: 16 commits merged into release/2.2 on Jan 31, 2024
49 changes: 31 additions & 18 deletions docsrc/user_guide/saving_models.rst
@@ -14,14 +14,18 @@ Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
Dynamo IR
-------------

Starting with 2.1 release of Torch-TensorRT, we are switching the default compilation to be dynamo based.
The output of `ir=dynamo` compilation is a `torch.fx.GraphModule` object. There are two ways to save these objects
The output of Torch-TensorRT compilation with `ir=dynamo` is a `torch.export.ExportedProgram` object by default.
In addition, we provide a new parameter `output_format` in the `CompilationSettings` object, specified before compilation.
The `output_format` parameter can take the following options:

a) Converting to Torchscript
* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
* `torchscript` (or) `ts` : This returns a TorchScript module
* `graph_module` (or) `fx` : This returns a `torch.fx.GraphModule`, which can be traced into Torchscript and saved to disk.

a) Torchscript
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`torch.fx.GraphModule` objects cannot be serialized directly. Hence we use `torch.jit.trace` to convert this into a `ScriptModule` object which can be saved to disk.
The following code illustrates this approach.
If you set `output_format="torchscript"`, compilation returns a `ScriptModule`, which can be serialized via `torch.jit.save`.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
trt_traced_model = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_traced_model, "trt_model.ts")
# trt_ts is a torch.jit.ScriptModule object
trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="torchscript")
torch.jit.save(trt_ts, "trt_model.ts")

# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
@@ -41,8 +45,7 @@
b) ExportedProgram
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`torch.export.ExportedProgram` is a new format introduced in Pytorch 2.1. After we compile a Pytorch module using Torch-TensorRT, the resultant
`torch.fx.GraphModule` along with additional metadata can be used to create `ExportedProgram` which can be saved and loaded from disk.
`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X, is the default return type of Torch-TensorRT compilation.

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
# Transform and create an exported program
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
torch.export.save(trt_exp_program, "trt_model.ep")
# trt_ep is a torch.export.ExportedProgram object
trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch.export.save(trt_ep, "trt_model.ep")

# Later, you can load it and run inference
model = torch.export.load("trt_model.ep")
model(*inputs)

`torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule into their corresponding nodes and stitches all the nodes together.
This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
c) GraphModule
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
Internally, the partitioning, lowering, and conversion phases operate on GraphModule objects. These can either be traced into a Torchscript module or
exported into an `ExportedProgram` object (see the sketch after the example below).

.. code-block:: python

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, output_format="graph_module")
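
As a follow-up sketch, the GraphModule returned above can then be serialized through either route mentioned earlier; the `torch_tensorrt.dynamo.export` call below is the helper referenced in the previous revision of this guide and is assumed here.

.. code-block:: python

# Option 1: trace the GraphModule into Torchscript and save it
trt_script = torch.jit.trace(trt_gm, inputs)
torch.jit.save(trt_script, "trt_model.ts")

# Option 2 (assumed helper): build an ExportedProgram and save it
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
torch.export.save(trt_exp_program, "trt_model.ep")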

Torchscript IR
-------------

In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT was through the Torchscript IR.
This behavior stays the same in 2.X versions as well.
For `ir=ts`, this behavior stays the same in 2.X versions as well.

.. code-block:: python

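# The original block is collapsed in this diff view. A minimal sketch of the
# ir="ts" workflow, assuming the same MyModel/inputs setup as the Dynamo
# examples above:
import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# Compiling with ir="ts" (or ir="torchscript") returns a torch.jit.ScriptModule
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs)
torch.jit.save(trt_ts, "trt_model.ts")

# Later, load and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)
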
4 changes: 3 additions & 1 deletion examples/int8/training/vgg16/vgg16.py
@@ -3,10 +3,12 @@
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""

from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce


class VGG(nn.Module):
10 changes: 6 additions & 4 deletions py/torch_tensorrt/_Device.py
@@ -32,12 +32,14 @@ class Device(object):
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
"""

device_type: Optional[
trt.DeviceType
] = None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
device_type: Optional[trt.DeviceType] = (
None #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
)
gpu_id: int = -1 #: Device ID for target GPU
dla_core: int = -1 #: Core ID for target DLA core
allow_gpu_fallback: bool = False #: Whether falling back to GPU if DLA cannot support an op should be allowed
allow_gpu_fallback: bool = (
False #: Whether falling back to GPU if DLA cannot support an op should be allowed
)

def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
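
For illustration, a minimal usage sketch (assuming the public torch_tensorrt.Device constructor accepts the attributes documented above as keyword arguments):

import torch_tensorrt

# Target GPU 0 explicitly
gpu_device = torch_tensorrt.Device(gpu_id=0)

# Target DLA core 0 and allow unsupported ops to fall back to the GPU
dla_device = torch_tensorrt.Device(dla_core=0, allow_gpu_fallback=True)
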
12 changes: 6 additions & 6 deletions py/torch_tensorrt/_Input.py
@@ -28,12 +28,12 @@ class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1

shape_mode: Optional[
_ShapeMode
] = None #: Is input statically or dynamically shaped
shape: Optional[
Tuple[int, ...] | Dict[str, Tuple[int, ...]]
] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
shape_mode: Optional[_ShapeMode] = (
None #: Is input statically or dynamically shaped
)
shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
)
dtype: _enums.dtype = (
_enums.dtype.unknown
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
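
For illustration, a minimal usage sketch (assuming the public torch_tensorrt.Input constructor accepts the shape forms documented above):

import torch
import torch_tensorrt

# Statically shaped input
static_in = torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.float32)

# Dynamically shaped input, using the min/opt/max form described above
dynamic_in = torch_tensorrt.Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(16, 3, 224, 224),
    dtype=torch.float32,
)
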
15 changes: 11 additions & 4 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -27,6 +27,7 @@
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
OUTPUT_FORMAT,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
@@ -38,6 +39,7 @@
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
)
from torch_tensorrt.dynamo._exporter import export
from torch_tensorrt.dynamo.conversion import (
CompilationSettings,
convert_module,
@@ -88,6 +90,7 @@ def compile(
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
output_format: str = OUTPUT_FORMAT,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -144,6 +147,7 @@ def compile(
use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -200,9 +204,9 @@
"device": device,
"workspace_size": workspace_size,
"min_block_size": min_block_size,
"torch_executed_ops": torch_executed_ops
if torch_executed_ops is not None
else set(),
"torch_executed_ops": (
torch_executed_ops if torch_executed_ops is not None else set()
),
"pass_through_build_failures": pass_through_build_failures,
"max_aux_streams": max_aux_streams,
"version_compatible": version_compatible,
@@ -219,11 +223,14 @@
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
"output_format": output_format,
}

settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)
return compile_module(gm, inputs, settings)
trt_gm = compile_module(gm, inputs, settings)
trt_result = export(trt_gm, torch_inputs, output_format)
return trt_result


def compile_module(
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -24,6 +24,7 @@
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REFIT = False
REQUIRE_FULL_COMPILATION = False
OUTPUT_FORMAT = "exported_program"


def default_device() -> Device: