You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docsrc/getting_started/getting_started_with_python_api.rst
+27-3
Original file line number
Diff line number
Diff line change
@@ -14,8 +14,7 @@ If given a ``torch.nn.Module`` and the ``ir`` flag is set to either ``default``
14
14
15
15
To compile your input ``torch.nn.Module`` with Torch-TensorRT, all you need to do is provide the module and inputs
16
16
to Torch-TensorRT and you will be returned an optimized TorchScript module to run or add into another PyTorch module. Inputs
17
-
is a list of ``torch_tensorrt.Input`` classes which define input's shape, datatype and memory format. You can also specify settings such as
18
-
operating precision for the engine or target device. After compilation you can save the module just like any other module
17
+
is a list of ``torch_tensorrt.Input`` classes which define input Tensors' shape, datatype and memory format. Alternatively, if your input is a more complex data type, such as a tuple or list of Tensors, you can use the ``input_signature`` argument to specify a collection-based input, such as ``(List[Tensor], Tuple[Tensor, Tensor])``. See the second sample below for an example. You can also specify settings such as operating precision for the engine or target device. After compilation you can save the module just like any other module
19
18
to load in a deployment application. In order to load a TensorRT/TorchScript module, make sure you first import ``torch_tensorrt``.
20
19
21
20
.. code-block:: python
@@ -44,6 +43,32 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
44
43
result = trt_ts_module(input_data)
45
44
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
46
45
46
+
.. code-block:: python
47
+
48
+
# Sample using collection-based inputs via the input_signature argument
49
+
import torch_tensorrt
50
+
51
+
...
52
+
53
+
model = MyModel().eval()
54
+
55
+
# input_signature expects a tuple of individual input arguments to the module
56
+
# The module below, for example, would have a docstring of the form:
@@ -55,4 +80,3 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
55
80
result = trt_ts_module(input_data)
56
81
57
82
Torch-TensorRT Python API also provides ``torch_tensorrt.ts.compile`` which accepts a TorchScript module as input and ``torch_tensorrt.fx.compile`` which accepts a FX GraphModule as input.
0 commit comments