Commit 9675c1b

fix: Allow low rank inputs in Python Runtime
- Enable support for 0D and 1D inputs in the Python runtime
- Remove shape checks, which are faulty in cases where inputs do not have an explicit batch
- Add regression test cases for the Python runtime
1 parent b7f1e85 commit 9675c1b
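
For background, the removed implicit-batch path read batch_size = inputs[0].shape[0], which cannot work for 0D (scalar) inputs. Below is a minimal sketch of that failure mode; it is illustration only, not part of the commit, and runs on CPU:

import torch

x = torch.tensor(3)  # 0D tensor: x.shape is torch.Size([]), x.dim() == 0

try:
    batch_size = x.shape[0]  # the old implicit-batch assumption
except IndexError:
    # A 0D tensor has no dimension 0 to index, so the old shape logic
    # could never accept scalar inputs.
    print("0D input has no batch dimension")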

File tree

2 files changed: +105 -38 lines changed

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+24 -38
@@ -1,13 +1,15 @@
 from __future__ import annotations

+import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple

+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
 import torch
 from torch.nn import Module
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
+logger = logging.getLogger(__name__)


 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]

@@ -22,14 +24,12 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
-        cuda_graph_batch_size: int = -1,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
-        self.cuda_graph_batch_size = cuda_graph_batch_size
         self.initialized = False
         self._initialize()

@@ -107,7 +107,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
-        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size

     def _load_from_state_dict(
         self,

@@ -156,8 +155,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self.input_names
         ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

-        # This is only used when the trt engine is using implicit batch dim.
-        batch_size = inputs[0].shape[0]
         contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
         bindings: List[Any] = [None] * (
             len(self.input_names)

@@ -166,37 +163,34 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         )

         for i, input_name in enumerate(self.input_names):
-            assert inputs[
-                i
-            ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+            if not contiguous_inputs[i].is_cuda:
+                logger.warning(
+                    f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                    "This tensor is being moved by the runtime but for performance considerations, "
+                    "ensure your inputs are all on GPU and open an issue here "
+                    "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                )
+                contiguous_inputs = (
+                    contiguous_inputs[:i]
+                    + [contiguous_inputs[i].cuda()]
+                    + contiguous_inputs[i + 1 :]
+                )
+
             assert (
-                inputs[i].dtype == self.input_dtypes[i]
-            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+                contiguous_inputs[i].dtype == self.input_dtypes[i]
+            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

             idx = self.input_binding_indices_in_order[i]
             bindings[idx] = contiguous_inputs[i].data_ptr()

-            if not self.engine.has_implicit_batch_dimension:
-                self.context.set_binding_shape(
-                    idx, tuple(contiguous_inputs[i].shape)
-                )
-            else:
-                assert inputs[i].size()[1:] == self.input_shapes[i], (
-                    f"Shape mismatch for {i}th input({input_name}). "
-                    f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
-                )
-
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:ProcessOutputs"
         ):
             # create output tensors
             outputs: List[torch.Tensor] = []

             for i, idx in enumerate(self.output_binding_indices_in_order):
-                if self.engine.has_implicit_batch_dimension:
-                    shape = (batch_size,) + self.output_shapes[i]
-                else:
-                    shape = tuple(self.context.get_binding_shape(idx))
+                shape = tuple(self.context.get_binding_shape(idx))

                 output = torch.empty(
                     size=shape,

@@ -207,10 +201,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 bindings[idx] = output.data_ptr()

             for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                if self.engine.has_implicit_batch_dimension:
-                    shape = (batch_size,) + self.hidden_output_shapes[i]
-                else:
-                    shape = tuple(self.context.get_binding_shape(idx))
+                shape = tuple(self.context.get_binding_shape(idx))

                 output = torch.empty(
                     size=shape,

@@ -222,14 +213,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:TensorRTRuntime"
         ):
-            if self.engine.has_implicit_batch_dimension:
-                self.context.execute_async(
-                    batch_size, bindings, torch.cuda.current_stream().cuda_stream
-                )
-            else:
-                self.context.execute_async_v2(
-                    bindings, torch.cuda.current_stream().cuda_stream
-                )
+            self.context.execute_async_v2(
+                bindings, torch.cuda.current_stream().cuda_stream
+            )

         if len(outputs) == 1:
             return outputs[0]
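
With this change, a CPU-resident input no longer trips an assertion: the runtime warns and moves the tensor to the GPU itself. A usage sketch under assumptions (trt_module stands for a hypothetical, already-constructed PythonTorchTensorRTModule with a single input named "x"):

import torch

x_cpu = torch.rand(3)  # 1D input accidentally left on the CPU

# Before this commit: AssertionError '0th input(x) is not on cuda device.'
# After this commit: logger.warning fires and the runtime substitutes
# x_cpu.cuda() into its contiguous-input list, so the call succeeds.
# y = trt_module(x_cpu)  # trt_module is hypothetical; construction elided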
New regression test file (+81 -0)
@@ -0,0 +1,81 @@
+import torch
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+
+class TestLowRankInputs(TestCase):
+    def test_0D_input(self):
+        class Tensor0DInput(torch.nn.Module):
+            def forward(self, x):
+                return x * 7
+
+        inputs = [
+            torch.tensor(
+                3,
+            )
+            .cuda()
+            .int(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Tensor0DInput())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg=f"0D-Tensor TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_1D_input(self):
+        class Tensor1DInput(torch.nn.Module):
+            def forward(self, x, y):
+                return (x + 7.1) / (y * 2.1)
+
+        inputs = [torch.rand((3, 1)).cuda(), torch.rand((3, 1)).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(Tensor1DInput())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg=f"1D-Tensor TRT outputs don't match with the original model.",
+        )
+
+        # Validate that the runtime moves cpu inputs to cuda
+        optimized_model(torch.rand((3, 1)).cuda(), torch.rand((3, 1)))
+
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()
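
A note on running the new cases: because the file ends with run_tests(), it can be executed directly with python on a CUDA-capable machine, or collected by pytest. Both tests compile through the "torch_compile" path with use_python_runtime=True, so they exercise exactly the runtime code changed above, and the final call in test_1D_input passes one CPU tensor to check the new CPU-to-CUDA input move.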
