Commit 5d563be

fix the globalpartitioner bug (#3157)

1 parent 4dfe909

2 files changed, +19 -3 lines changed

py/torch_tensorrt/dynamo/_compiler.py (+3 -1)

@@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
-
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
+        # filter on the GraphModule
+        if not isinstance(submodule, torch.fx.graph_module.GraphModule):
+            continue
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
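Note on the guard above: named_children() on a partitioned module can yield children that are not torch.fx.GraphModule instances (for example, leaf modules that the tracer kept intact), and the conversion loop assumes each child is a partition subgraph. A minimal sketch of the same filtering pattern, where ToyModel and its nn.Linear child are illustrative assumptions, not code from this commit:

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Linear stays a leaf module under symbolic_trace,
        # so it shows up in named_children() as a non-GraphModule child
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


gm = torch.fx.symbolic_trace(ToyModel())

for name, _ in gm.named_children():
    submodule = getattr(gm, name)
    # Same guard as the fix: only fx.GraphModule children are partition
    # subgraphs; other children (here, the nn.Linear leaf) are skipped.
    if not isinstance(submodule, torch.fx.graph_module.GraphModule):
        print(f"skipping non-GraphModule child: {name}")
        continue
    print(f"candidate partition submodule: {name}")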

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py (+16 -2)

@@ -228,8 +228,22 @@ def partition(
     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
-
+    # TODO: confirm with Naren whether this change is required or not
+    # tested both with and without this change, it both works
+    # the only difference is the graph node name, an example is as below:
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %_run_on_acc_0 : [num_users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
+    #     return (_run_on_acc_0,)
+
+    # or
+
+    # graph():
+    #     %x : [num_users=1] = placeholder[target=x]
+    #     %fused_0 : [num_users=1] = call_module[target=fused_0](args = (%x,), kwargs = {})
+    #     return (fused_0,)
+
+    fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
     if verbose:
         supported_ops.print_support_overview(len(partitions))

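Context for this change: fuse_partitions names fused submodules fused_0, fused_1, ... by default, while the conversion loop in _compiler.py keys off the "_run_on_acc" substring in submodule names, so passing prefix="_run_on_acc_" aligns the global partitioner's output with what the compiler expects. A minimal, self-contained sketch of the same propose/fuse flow; the toy function and the ReluAddSupport class are illustrative assumptions, and the prefix keyword is assumed available in the installed torch.fx (the diff above relies on it):

import operator

import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport


class ReluAddSupport(OperatorSupport):
    # Mark only relu and add as "supported" so the partitioner has
    # something concrete to group into an accelerated subgraph.
    def is_node_supported(self, submodules, node):
        return node.op == "call_function" and node.target in (torch.relu, operator.add)


def fn(x):
    return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(fn)
partitioner = CapabilityBasedPartitioner(gm, ReluAddSupport())
partitions = partitioner.propose_partitions()

# Default naming would yield fused_0; the prefix below yields _run_on_acc_0,
# which matches the "_run_on_acc" check in _compiler.py.
fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
print(fused_graph.graph)  # expect call_module[target=_run_on_acc_0]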