
Commit 985f6a2

chore: Add samples of input_signature usage to docs
- Add documentation to `README` for usage of input signature
- Add documentation to "Getting Started" page for usage of input signature
1 parent 619b9a0 commit 985f6a2

File tree (2 files changed: +34 −3 lines)

  README.md
  docsrc/getting_started/getting_started_with_python_api.rst

README.md

+7 lines changed
@@ -73,6 +73,7 @@ import torch_tensorrt
 ...

 trt_ts_module = torch_tensorrt.compile(torch_script_module,
+    # If the inputs to the module are plain Tensors, specify them via the `inputs` argument:
     inputs = [example_tensor, # Provide example tensor for input shape or...
         torch_tensorrt.Input( # Specify input object with shape and dtype
             min_shape=[1, 3, 224, 224],
@@ -81,6 +82,12 @@ trt_ts_module = torch_tensorrt.compile(torch_script_module,
             # For static size shape=[1, 3, 224, 224]
             dtype=torch.half) # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
     ],
+
+    # For inputs containing tuples or lists of tensors, use the `input_signature` argument:
+    # Below, we have an input consisting of a Tuple of two Tensors (Tuple[Tensor, Tensor])
+    # input_signature = ( (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
+    #                      torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)), ),
+
     enabled_precisions = {torch.half}, # Run with FP16
 )
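For reference, here is a minimal sketch of what the commented-out `input_signature` variant above could look like when written out in full. It assumes a scripted module (called `pair_module` here, a hypothetical placeholder not part of the commit) whose `forward` takes a single `Tuple[Tensor, Tensor]` argument; the shapes and dtype mirror the comment in the diff.

    # Illustrative sketch only -- `pair_module` is a hypothetical scripted module with
    # forward(self, pair: Tuple[torch.Tensor, torch.Tensor]); it is not part of the commit.
    import torch
    import torch_tensorrt

    trt_ts_module = torch_tensorrt.compile(
        pair_module,
        # One entry per forward() argument; the inner tuple mirrors Tuple[Tensor, Tensor]
        input_signature=(
            (torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half),
             torch_tensorrt.Input(shape=[1, 3, 224, 224], dtype=torch.half)),
        ),
        enabled_precisions={torch.half},  # Run with FP16
    )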

docsrc/getting_started/getting_started_with_python_api.rst

+27 −3 lines changed
@@ -14,8 +14,7 @@ If given a ``torch.nn.Module`` and the ``ir`` flag is set to either ``default``

 To compile your input ``torch.nn.Module`` with Torch-TensorRT, all you need to do is provide the module and inputs
 to Torch-TensorRT and you will be returned an optimized TorchScript module to run or add into another PyTorch module. Inputs
-is a list of ``torch_tensorrt.Input`` classes which define input's shape, datatype and memory format. You can also specify settings such as
-operating precision for the engine or target device. After compilation you can save the module just like any other module
+is a list of ``torch_tensorrt.Input`` classes which define input Tensors' shape, datatype and memory format. Alternatively, if your input is a more complex data type, such as a tuple or list of Tensors, you can use the ``input_signature`` argument to specify a collection-based input, such as ``(List[Tensor], Tuple[Tensor, Tensor])``. See the second sample below for an example. You can also specify settings such as operating precision for the engine or target device. After compilation you can save the module just like any other module
 to load in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import ``torch_tensorrt``.

 .. code-block:: python
@@ -44,6 +43,32 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
     result = trt_ts_module(input_data)
     torch.jit.save(trt_ts_module, "trt_ts_module.ts")

+.. code-block:: python
+
+    # Sample using collection-based inputs via the input_signature argument
+    import torch_tensorrt
+
+    ...
+
+    model = MyModel().eval()
+
+    # input_signature expects a tuple of individual input arguments to the module
+    # The module below, for example, would have a docstring of the form:
+    # def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
+    input_signature = (
+        [torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
+        (torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
+    )
+    enabled_precisions = {torch.float, torch.half}
+
+    trt_ts_module = torch_tensorrt.compile(
+        model, input_signature=input_signature, enabled_precisions=enabled_precisions
+    )
+
+    input_data = input_data.to("cuda").half()
+    result = trt_ts_module(input_data)
+    torch.jit.save(trt_ts_module, "trt_ts_module.ts")
+
 .. code-block:: python

     # Deployment application
@@ -55,4 +80,3 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
     result = trt_ts_module(input_data)

 Torch-TensorRT Python API also provides ``torch_tensorrt.ts.compile`` which accepts a TorchScript module as input and ``torch_tensorrt.fx.compile`` which accepts a FX GraphModule as input.
-
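As a follow-up to the sample added above, here is a minimal sketch of how runtime inputs matching the documented collection signature (`forward(self, input0: List[Tensor], input1: Tuple[Tensor, Tensor])`) might be constructed and passed to the compiled module. The variable names, random data, and device/dtype handling are illustrative assumptions, not part of the commit.

    # Illustrative sketch only: build runtime inputs that match the collection signature above.
    # Assumes `trt_ts_module` was produced by the compile call in the new sample.
    import torch

    # List[Tensor] argument (input0) and Tuple[Tensor, Tensor] argument (input1),
    # both on the GPU in FP16 to match the compiled shapes and dtypes
    input0 = [torch.randn(64, 64, device="cuda").half(),
              torch.randn(64, 64, device="cuda").half()]
    input1 = (torch.randn(64, 64, device="cuda").half(),
              torch.randn(64, 64, device="cuda").half())

    result = trt_ts_module(input0, input1)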
