@@ -8,10 +8,7 @@
 import torch
 import torch_tensorrt
-from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo.conversion import DYNAMIC_DIM
-from torch_tensorrt.dynamo.utils import input_is_dynamic
 from torch_tensorrt.runtime._utils import _is_switch_required, _select_rt_device

 logger = logging.getLogger(__name__)
@@ -21,25 +18,18 @@ class WrapperTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This wrapper runtime module records and replays the whole CUDA graph across submodules

     Args:
-        original_module: Unmodified FX GraphModule
         compiled_module: Compiled FX GraphModule that will be wrapped
-        output_shapes: Shapes of output Tensors of the graph
-        output_dtypes: Output data types of the graph
     Returns:
         Output tensor or tensor list
     """

     def __init__(
         self,
         compiled_module: torch.nn.Module,
-        output_shapes: List[torch.Size],
-        output_dtypes: List[torch.dtype],
     ):
         super(WrapperTorchTensorRTModule, self).__init__()
         self.compiled_module = compiled_module
         self.inputs = partitioning.construct_submodule_inputs(compiled_module)
-        self.output_shapes = output_shapes
-        self.output_dtypes = output_dtypes

         self._input_buffers: List[torch.Tensor] = []
         self._output_buffers: List[torch.Tensor] = []
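For context, a minimal standalone sketch (not this module's code, only public torch.cuda APIs) of the record/replay pattern the docstring refers to: capture one forward pass into a torch.cuda.CUDAGraph, then replay it after copying fresh data into the captured input buffer.

    import torch

    model = torch.nn.Linear(16, 8).cuda().eval()
    static_input = torch.randn(4, 16, device="cuda")

    # Warm up on a side stream before capture, as CUDA Graphs require.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)  # recorded, not run eagerly

    # Replay: refill the captured input buffer in place, then launch the graph.
    static_input.copy_(torch.randn(4, 16, device="cuda"))
    g.replay()
    result = static_output.clone()  # clone before the next replay overwrites it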
@@ -49,7 +39,6 @@ def __init__(
         self.prev_cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
-        self.input_is_dynamic = input_is_dynamic(self.inputs)

         # Disable cudagraphs in submodules as it will be enabled in the wrapper
         for name, rt_mod in self.compiled_module.named_children():
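The _caller_stream/_engine_stream attributes support a two-way stream handshake at runtime. A minimal standalone sketch of that pattern, using only public torch.cuda stream APIs (the variable names are illustrative):

    import torch

    caller_stream = torch.cuda.current_stream()
    engine_stream = torch.cuda.Stream()

    engine_stream.wait_stream(caller_stream)   # engine sees the caller's writes
    with torch.cuda.stream(engine_stream):
        pass  # enqueue engine execution (or graph replay) here
    caller_stream.wait_stream(engine_stream)   # caller sees the engine's outputs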
@@ -82,18 +71,9 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
         new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)

-        # If the new shape key differs from the existing one, infer new output shape
         if new_shape_key != self.shape_key:
             logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
-
-            if self.input_is_dynamic:
-                with FakeTensorMode(allow_non_fake_inputs=True):
-                    tmp_outputs = self.compiled_module(*inputs)
-                    if not isinstance(tmp_outputs, (list, tuple)):
-                        tmp_outputs = [tmp_outputs]
-                    self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
-            print("self.output_shapes ", self.output_shapes)
             return True

         return False
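The block removed here inferred output shapes by re-running the module under FakeTensorMode whenever the shape key changed (leftover debug print included). A standalone sketch of that now-removed technique: fake tensors propagate shapes through the ops without allocating real memory or launching kernels.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    def infer_output_shapes(module: torch.nn.Module, inputs) -> list:
        # Real input tensors are allowed in; the traced ops stay fake.
        with FakeTensorMode(allow_non_fake_inputs=True):
            outs = module(*inputs)
        if not isinstance(outs, (list, tuple)):
            outs = [outs]
        return [tuple(o.shape) for o in outs]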
@@ -128,7 +108,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     self.cudagraph.reset()

                 self._input_buffers = [None] * len(self.inputs)
-                self._output_buffers = [None] * len(self.output_shapes)

             if not cudagraphs_enabled and self.cudagraph:
                 self.cudagraph.reset()
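A minimal sketch of the invalidation step this hunk touches (the class and names are illustrative, not the module's API): when input shapes change or CUDA graphs are toggled off, the old capture is reset and the static buffers dropped so the next call re-records.

    import torch
    from typing import Optional

    class GraphState:
        def __init__(self, num_inputs: int) -> None:
            self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
            self.input_buffers: list = [None] * num_inputs

        def invalidate(self, num_inputs: int) -> None:
            if self.cudagraph is not None:
                self.cudagraph.reset()  # free the captured graph
                self.cudagraph = None
            self.input_buffers = [None] * num_inputs  # force re-allocation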
@@ -202,32 +181,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 elif cudagraphs_enabled:
                     self._input_buffers[i].copy_(contiguous_inputs[i])

-            with (
-                torch.autograd.profiler.record_function(
-                    "WrapperTorchTensorRTModule:ProcessOutputs"
-                )
-                if self.profiling_enabled
-                else nullcontext()
-            ):
-                # create output tensors
-                outputs: List[torch.Tensor] = []
-
-                for o, shape in enumerate(self.output_shapes):
-                    if DYNAMIC_DIM in shape:
-                        raise ValueError(
-                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
-                        )
-
-                    output = torch.empty(
-                        size=shape,
-                        dtype=self.output_dtypes[o],
-                        device=torch.cuda.current_device(),
-                    )
-
-                    outputs.append(output)
-
-                    if need_cudagraphs_record:
-                        self._output_buffers[o] = outputs[o].clone()
             with (
                 torch.autograd.profiler.record_function(
                     "WrapperTorchTensorRTModule:TensorRTRuntime"
@@ -277,13 +230,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                         output_buffers = self._output_buffers
                     else:
                         output_buffers = [self._output_buffers]
-                    for idx, o in enumerate(outputs):
-                        o.copy_(output_buffers[idx])
-
+                    outputs = [output.clone() for output in output_buffers]
                     if len(outputs) == 1:
                         return outputs[0]

                     return outputs
                 else:
-
                     return outputs
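The change here stops copying replayed results into tensors preallocated from recorded output shapes and instead clones the graph's own output buffers, which replay() rewrites in place. A standalone sketch of why the clone is needed (g and static_outputs are assumed to come from an earlier capture, as in the first sketch above):

    import torch
    from typing import List, Union

    def run_replay(
        g: torch.cuda.CUDAGraph, static_outputs: List[torch.Tensor]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        g.replay()
        # Detach results from the capture buffers; the next replay would
        # otherwise overwrite them under the caller's feet.
        outputs = [o.clone() for o in static_outputs]
        return outputs[0] if len(outputs) == 1 else outputs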