File tree 2 files changed +19
-3
lines changed
2 files changed +19
-3
lines changed Original file line number Diff line number Diff line change @@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
366
366
367
367
# Partition module into components that can be TRT-accelerated
368
368
fast_partitioner_failed = False
369
-
370
369
# If specified, try using the fast partitioner and fall back to the global one on failure
371
370
if settings .use_fast_partitioner :
372
371
try :
@@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
408
407
# Generate the corresponding TRT Module for those
409
408
for name , _ in partitioned_module .named_children ():
410
409
submodule = getattr (partitioned_module , name )
410
+ # filter on the GraphModule
411
+ if not isinstance (submodule , torch .fx .graph_module .GraphModule ):
412
+ continue
411
413
# Criteria for a module to be convertible to TRT
412
414
if settings .use_fast_partitioner and "_run_on_acc" not in name :
413
415
dryrun_tracker .to_run_in_torch .extend (parse_non_trt_nodes (submodule ))
Original file line number Diff line number Diff line change @@ -228,8 +228,22 @@ def partition(
228
228
# Determine partitions based on user specifications and operator support
229
229
# Then, fuse partitions and display overview of supported/unsupported operators
230
230
partitions = partitioner .propose_partitions ()
231
- fused_graph = partitioner .fuse_partitions (partitions )
232
-
231
+ # TODO: confirm with Naren whether this change is required or not
232
+ # tested both with and without this change, it both works
233
+ # the only difference is the graph node name, an example is as below:
234
+ # graph():
235
+ # %x : [num_users=1] = placeholder[target=x]
236
+ # %_run_on_acc_0 : [num_users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
237
+ # return (_run_on_acc_0,)
238
+
239
+ # or
240
+
241
+ # graph():
242
+ # %x : [num_users=1] = placeholder[target=x]
243
+ # %fused_0 : [num_users=1] = call_module[target=fused_0](args = (%x,), kwargs = {})
244
+ # return (fused_0,)
245
+
246
+ fused_graph = partitioner .fuse_partitions (partitions , prefix = "_run_on_acc_" )
233
247
if verbose :
234
248
supported_ops .print_support_overview (len (partitions ))
235
249
You can’t perform that action at this time.
0 commit comments