@@ -1,13 +1,14 @@
from __future__ import annotations

+import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple

+import tensorrt as trt
import torch
from torch.nn import Module
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
+logger = logging.getLogger(__name__)


class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
@@ -22,14 +23,12 @@ def __init__(
        engine: trt.ICudaEngine,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
-        cuda_graph_batch_size: int = -1,
    ):
        super(PythonTorchTensorRTModule, self).__init__()
        self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
        self.engine = engine
        self.input_names = input_names if input_names is not None else []
        self.output_names = output_names if output_names is not None else []
-        self.cuda_graph_batch_size = cuda_graph_batch_size
        self.initialized = False
        self._initialize()

@@ -107,7 +106,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> Non
        state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
        state_dict[prefix + "input_names"] = self.input_names
        state_dict[prefix + "output_names"] = self.output_names
-        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size

    def _load_from_state_dict(
        self,
@@ -156,8 +154,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                    self.input_names
                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

-                # This is only used when the trt engine is using implicit batch dim.
-                batch_size = inputs[0].shape[0]
                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
                bindings: List[Any] = [None] * (
                    len(self.input_names)
@@ -166,25 +162,29 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                )

                for i, input_name in enumerate(self.input_names):
-                    assert inputs[
-                        i
-                    ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+                    if not contiguous_inputs[i].is_cuda:
+                        logger.warning(
+                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            "This tensor is being moved by the runtime but for performance considerations, "
+                            "ensure your inputs are all on GPU and open an issue here "
+                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                        )
+                        contiguous_inputs = (
+                            contiguous_inputs[:i]
+                            + [contiguous_inputs[i].cuda()]
+                            + contiguous_inputs[i + 1 :]
+                        )
+
                    assert (
-                        inputs[i].dtype == self.input_dtypes[i]
-                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+                        contiguous_inputs[i].dtype == self.input_dtypes[i]
+                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

                    idx = self.input_binding_indices_in_order[i]
                    bindings[idx] = contiguous_inputs[i].data_ptr()

-                    if not self.engine.has_implicit_batch_dimension:
-                        self.context.set_binding_shape(
-                            idx, tuple(contiguous_inputs[i].shape)
-                        )
-                    else:
-                        assert inputs[i].size()[1:] == self.input_shapes[i], (
-                            f"Shape mismatch for {i}th input({input_name}). "
-                            f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
-                        )
+                    self.context.set_binding_shape(
+                        idx, tuple(contiguous_inputs[i].shape)
+                    )

            with torch.autograd.profiler.record_function(
                "PythonTorchTensorRTModule:ProcessOutputs"
@@ -193,10 +193,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                outputs: List[torch.Tensor] = []

                for i, idx in enumerate(self.output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))

                    output = torch.empty(
                        size=shape,
@@ -207,10 +204,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                    bindings[idx] = output.data_ptr()

                for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.hidden_output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))

                    output = torch.empty(
                        size=shape,
@@ -222,14 +216,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
            with torch.autograd.profiler.record_function(
                "PythonTorchTensorRTModule:TensorRTRuntime"
            ):
-                if self.engine.has_implicit_batch_dimension:
-                    self.context.execute_async(
-                        batch_size, bindings, torch.cuda.current_stream().cuda_stream
-                    )
-                else:
-                    self.context.execute_async_v2(
-                        bindings, torch.cuda.current_stream().cuda_stream
-                    )
+                self.context.execute_async_v2(
+                    bindings, torch.cuda.current_stream().cuda_stream
+                )

        if len(outputs) == 1:
            return outputs[0]
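For reference, a minimal usage sketch of PythonTorchTensorRTModule after this diff. It is not part of the change itself: the engine path sample_engine.trt, the binding names x and output0, and importing the class from torch_tensorrt.dynamo.runtime are illustrative assumptions, not taken from the PR.

# Usage sketch (assumptions noted above): wrap a prebuilt explicit-batch
# TensorRT engine and run it through the module's forward path.
import tensorrt as trt
import torch

from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule  # assumed import path

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(TRT_LOGGER)

# Deserialize an engine that was built and saved with explicit batch dimensions.
with open("sample_engine.trt", "rb") as f:  # hypothetical path
    engine = runtime.deserialize_cuda_engine(f.read())

# The constructor no longer accepts cuda_graph_batch_size after this change.
trt_module = PythonTorchTensorRTModule(
    engine=engine,
    input_names=["x"],  # hypothetical binding names
    output_names=["output0"],
)

# A CPU input is now moved to CUDA by the runtime (with a warning) instead of
# failing an assertion, and the binding shape is taken from the tensor itself,
# so the batch size is whatever the input carries.
x = torch.randn(8, 3, 224, 224, dtype=torch.float32)
out = trt_module(x)
print(out.shape, out.device)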