@@ -1,13 +1,15 @@
 from __future__ import annotations
 
+import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
+# @manual=//deeplearning/trt/python:py_tensorrt
+import tensorrt as trt
 import torch
 from torch.nn import Module
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
-# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
+logger = logging.getLogger(__name__)
 
 
 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
@@ -22,14 +24,12 @@ def __init__(
         engine: trt.ICudaEngine,
         input_names: Optional[List[str]] = None,
         output_names: Optional[List[str]] = None,
-        cuda_graph_batch_size: int = -1,
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
-        self.cuda_graph_batch_size = cuda_graph_batch_size
         self.initialized = False
         self._initialize()
 
@@ -107,7 +107,6 @@ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
         state_dict[prefix + "output_names"] = self.output_names
-        state_dict[prefix + "cuda_graph_batch_size"] = self.cuda_graph_batch_size
 
     def _load_from_state_dict(
         self,
@@ -156,8 +155,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."
 
-                # This is only used when the trt engine is using implicit batch dim.
-                batch_size = inputs[0].shape[0]
                 contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
                 bindings: List[Any] = [None] * (
                     len(self.input_names)
@@ -166,37 +163,34 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 )
 
                 for i, input_name in enumerate(self.input_names):
-                    assert inputs[
-                        i
-                    ].is_cuda, f"{i}th input({input_name}) is not on cuda device."
+                    if not contiguous_inputs[i].is_cuda:
+                        logger.warning(
+                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            "This tensor is being moved by the runtime but for performance considerations, "
+                            "ensure your inputs are all on GPU and open an issue here "
+                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                        )
+                        contiguous_inputs = (
+                            contiguous_inputs[:i]
+                            + [contiguous_inputs[i].cuda()]
+                            + contiguous_inputs[i + 1 :]
+                        )
+
                     assert (
-                        inputs[i].dtype == self.input_dtypes[i]
-                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
+                        contiguous_inputs[i].dtype == self.input_dtypes[i]
+                    ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
                     idx = self.input_binding_indices_in_order[i]
                     bindings[idx] = contiguous_inputs[i].data_ptr()
 
-                    if not self.engine.has_implicit_batch_dimension:
-                        self.context.set_binding_shape(
-                            idx, tuple(contiguous_inputs[i].shape)
-                        )
-                    else:
-                        assert inputs[i].size()[1:] == self.input_shapes[i], (
-                            f"Shape mismatch for {i}th input({input_name}). "
-                            f"Expect {self.input_shapes[i]}, got {inputs[i].size()[1:]}."
-                        )
-
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:ProcessOutputs"
             ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []
 
                 for i, idx in enumerate(self.output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))
 
                     output = torch.empty(
                         size=shape,
@@ -207,10 +201,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     bindings[idx] = output.data_ptr()
 
                 for i, idx in enumerate(self.hidden_output_binding_indices_in_order):
-                    if self.engine.has_implicit_batch_dimension:
-                        shape = (batch_size,) + self.hidden_output_shapes[i]
-                    else:
-                        shape = tuple(self.context.get_binding_shape(idx))
+                    shape = tuple(self.context.get_binding_shape(idx))
 
                     output = torch.empty(
                         size=shape,
@@ -222,14 +213,9 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             with torch.autograd.profiler.record_function(
                 "PythonTorchTensorRTModule:TensorRTRuntime"
            ):
-                if self.engine.has_implicit_batch_dimension:
-                    self.context.execute_async(
-                        batch_size, bindings, torch.cuda.current_stream().cuda_stream
-                    )
-                else:
-                    self.context.execute_async_v2(
-                        bindings, torch.cuda.current_stream().cuda_stream
-                    )
+                self.context.execute_async_v2(
+                    bindings, torch.cuda.current_stream().cuda_stream
+                )
 
             if len(outputs) == 1:
                 return outputs[0]
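
Usage sketch (not part of the diff): after this change the module only accepts explicit-batch engines, there is no cuda_graph_batch_size argument, and non-CUDA inputs are moved to the GPU with a logged warning instead of failing an assertion. The snippet below illustrates that flow; the engine file name, the binding names, and the torch_tensorrt.dynamo.runtime import path are assumptions for illustration, while the constructor arguments (engine, input_names, output_names) come from the diff above, and the TensorRT calls (trt.Runtime, deserialize_cuda_engine) are standard TensorRT 8.x Python API.

import tensorrt as trt
import torch

# Assumed import location for the class edited above.
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule

# Deserialize a previously built explicit-batch engine (hypothetical file name).
with open("resnet18_fp16.engine", "rb") as f:
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    engine = runtime.deserialize_cuda_engine(f.read())

# Only the engine and binding names remain in the constructor; the names here are
# assumptions about the engine's input/output bindings.
trt_mod = PythonTorchTensorRTModule(
    engine=engine,
    input_names=["input_0"],
    output_names=["output_0"],
)

# CPU inputs would now be moved to the GPU with a warning rather than an assertion
# error, but passing CUDA tensors avoids the extra host-to-device copy.
x = torch.randn(8, 3, 224, 224, device="cuda")
out = trt_mod(x)
print(out.shape)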