Commit 1033dff

fix: Allow low rank inputs in Python Runtime (#2282)
1 parent b2aa255 commit 1033dff

File tree

2 files changed: +107 -37 lines changed

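Context for the fix: the runtime previously read a batch size from the first input unconditionally (`batch_size = inputs[0].shape[0]`, removed below) to support implicit-batch engines. A rank-0 tensor has an empty shape, so that indexing fails before the engine ever executes. A minimal repro sketch of the failure mode (illustrative, not part of the commit):

    import torch

    x = torch.tensor(3).int()  # rank-0 ("low rank") tensor: x.shape == torch.Size([])
    print(x.dim())             # 0

    # Equivalent of the removed line `batch_size = inputs[0].shape[0]`:
    try:
        _ = x.shape[0]
    except IndexError as err:
        print(f"Failed on 0-D input: {err}")

With implicit-batch support dropped, the runtime can rely on the engine's binding shapes (`set_binding_shape`/`get_binding_shape`) for every input and output, which handle 0-D and 1-D tensors uniformly.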

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

+26 -37
@@ -1,13 +1,14 @@
 from __future__ import annotations

+import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple

+import tensorrt as trt
 import torch
 from torch.nn import Module
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
+logger = logging.getLogger(__name__)


 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
@@ -22,14 +23,12 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
-        cuda_graph_batch_size: int = -1,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
-        self.cuda_graph_batch_size = cuda_graph_batch_size
         self.initialized = False
         self._initialize()

@@ -107,7 +106,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
-        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size

     def _load_from_state_dict(
         self,
@@ -156,8 +154,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self.input_names
         ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

-        # This is only used when the trt engine is using implicit batch dim.
-        batch_size = inputs[0].shape[0]
         contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
         bindings: List[Any] = [None] * (
             len(self.input_names)
@@ -166,25 +162,29 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         )

         for i, input_name in enumerate(self.input_names):
-            assert inputs[
-                i
-            ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+            if not contiguous_inputs[i].is_cuda:
+                logger.warning(
+                    f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                    "This tensor is being moved by the runtime but for performance considerations, "
+                    "ensure your inputs are all on GPU and open an issue here "
+                    "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                )
+                contiguous_inputs = (
+                    contiguous_inputs[:i]
+                    + [contiguous_inputs[i].cuda()]
+                    + contiguous_inputs[i + 1 :]
+                )
+
             assert (
-                inputs[i].dtype == self.input_dtypes[i]
-            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+                contiguous_inputs[i].dtype == self.input_dtypes[i]
+            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

             idx = self.input_binding_indices_in_order[i]
             bindings[idx] = contiguous_inputs[i].data_ptr()

-            if not self.engine.has_implicit_batch_dimension:
-                self.context.set_binding_shape(
-                    idx, tuple(contiguous_inputs[i].shape)
-                )
-            else:
-                assert inputs[i].size()[1:] == self.input_shapes[i], (
-                    f"Shape mismatch for {i}th input({input_name}). "
-                    f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
-                )
+            self.context.set_binding_shape(
+                idx, tuple(contiguous_inputs[i].shape)
+            )

         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:ProcessOutputs"
@@ -193,10 +193,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             outputs: List[torch.Tensor] = []

             for i, idx in enumerate(self.output_binding_indices_in_order):
-                if self.engine.has_implicit_batch_dimension:
-                    shape = (batch_size,) + self.output_shapes[i]
-                else:
-                    shape = tuple(self.context.get_binding_shape(idx))
+                shape = tuple(self.context.get_binding_shape(idx))

                 output = torch.empty(
                     size=shape,
@@ -207,10 +204,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 bindings[idx] = output.data_ptr()

             for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                if self.engine.has_implicit_batch_dimension:
-                    shape = (batch_size,) + self.hidden_output_shapes[i]
-                else:
-                    shape = tuple(self.context.get_binding_shape(idx))
+                shape = tuple(self.context.get_binding_shape(idx))

                 output = torch.empty(
                     size=shape,
@@ -222,14 +216,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with torch.autograd.profiler.record_function(
             "PythonTorchTensorRTModule:TensorRTRuntime"
         ):
-            if self.engine.has_implicit_batch_dimension:
-                self.context.execute_async(
-                    batch_size, bindings, torch.cuda.current_stream().cuda_stream
-                )
-            else:
-                self.context.execute_async_v2(
-                    bindings, torch.cuda.current_stream().cuda_stream
-                )
+            self.context.execute_async_v2(
+                bindings, torch.cuda.current_stream().cuda_stream
+            )

         if len(outputs) == 1:
             return outputs[0]
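
The input-handling change above replaces a hard device assert with a warn-and-move policy. A standalone sketch of the same pattern (the `prepare_inputs` helper is hypothetical, not part of the module's API):

    import logging
    from typing import List

    import torch

    logger = logging.getLogger(__name__)


    def prepare_inputs(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
        # Mirror the runtime's policy: tolerate CPU tensors but warn, since a
        # host-to-device copy on every forward call is a performance trap.
        prepared = []
        for i, t in enumerate(tensors):
            if not t.is_cuda:
                logger.warning(f"Input {i} is on {t.device}; moving it to CUDA.")
                t = t.cuda()
            prepared.append(t.contiguous())
        return prepared

Note that the runtime rebuilds `contiguous_inputs` with a list splice rather than mutating it in place, and the dtype assert now checks the possibly-moved `contiguous_inputs[i]` instead of the original `inputs[i]`.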
+81 -0
@@ -0,0 +1,81 @@
+import torch
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+
+class TestLowRankInputs(TestCase):
+    def test_0D_input(self):
+        class Tensor0DInput(torch.nn.Module):
+            def forward(self, x):
+                return x * 7
+
+        inputs = [
+            torch.tensor(
+                3,
+            )
+            .cuda()
+            .int(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(Tensor0DInput())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg=f"0D-Tensor TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_1D_input(self):
+        class Tensor1DInput(torch.nn.Module):
+            def forward(self, x, y):
+                return (x + 7.1) / (y * 2.1)
+
+        inputs = [torch.rand((3, 1)).cuda(), torch.rand((3, 1)).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(Tensor1DInput())
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg=f"1D-Tensor TRT outputs don't match with the original model.",
+        )
+
+        # Validate that the runtime moves cpu inputs to cuda
+        optimized_model(torch.rand((3, 1)), torch.rand((3, 1)))
+
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()
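
Both tests compile small modules through the `torch_compile` frontend with `use_python_runtime=True` and compare the TRT results against the eager FX graph; the final call in `test_1D_input` deliberately passes CPU tensors to confirm the runtime now moves them to CUDA instead of asserting. Since the file ends in a `run_tests()` main guard, it can be executed directly on a CUDA machine (the filename below is illustrative):

    python test_low_rank_inputs.py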
