Skip to content

Commit 780e398

Browse files
authored
chore/fix: Update TRTInterpreter impl in Dynamo compile [1 / x] (#2002)
1 parent 07a8c22 commit 780e398

File tree

4 files changed

+20
-14
lines changed

4 files changed

+20
-14
lines changed

Diff for: py/torch_tensorrt/dynamo/backend/__init__.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch_tensorrt.dynamo.backend._defaults import (
1515
PRECISION,
1616
DEBUG,
17-
MAX_WORKSPACE_SIZE,
17+
WORKSPACE_SIZE,
1818
MIN_BLOCK_SIZE,
1919
PASS_THROUGH_BUILD_FAILURES,
2020
)
@@ -35,7 +35,7 @@ def compile(
3535
debug=DEBUG,
3636
capability=EngineCapability.default,
3737
num_avg_timing_iters=1,
38-
workspace_size=MAX_WORKSPACE_SIZE,
38+
workspace_size=WORKSPACE_SIZE,
3939
dla_sram_size=1048576,
4040
dla_local_dram_size=1073741824,
4141
dla_global_dram_size=536870912,
@@ -105,7 +105,7 @@ def compile(
105105
def create_backend(
106106
precision: LowerPrecision = PRECISION,
107107
debug: bool = DEBUG,
108-
workspace_size: int = MAX_WORKSPACE_SIZE,
108+
workspace_size: int = WORKSPACE_SIZE,
109109
min_block_size: int = MIN_BLOCK_SIZE,
110110
torch_executed_ops: Sequence[str] = set(),
111111
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
@@ -114,10 +114,12 @@ def create_backend(
114114
"""Create torch.compile backend given specified arguments
115115
116116
Args:
117-
precision:
118-
debug: Whether to print out verbose debugging information
119-
workspace_size: Maximum workspace TRT is allowed to use for the module
120117
precision: Model Layer precision
118+
debug: Whether to print out verbose debugging information
119+
workspace_size: Workspace TRT is allowed to use for the module (0 is default)
120+
min_block_size: Minimum number of operators per TRT-Engine Block
121+
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
122+
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
121123
Returns:
122124
Backend for torch.compile
123125
"""

Diff for: py/torch_tensorrt/dynamo/backend/_defaults.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33

44
PRECISION = LowerPrecision.FP32
55
DEBUG = False
6-
MAX_WORKSPACE_SIZE = 20 << 30
6+
WORKSPACE_SIZE = 0
77
MIN_BLOCK_SIZE = 5
88
PASS_THROUGH_BUILD_FAILURES = False

Diff for: py/torch_tensorrt/dynamo/backend/_settings.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch_tensorrt.dynamo.backend._defaults import (
66
PRECISION,
77
DEBUG,
8-
MAX_WORKSPACE_SIZE,
8+
WORKSPACE_SIZE,
99
MIN_BLOCK_SIZE,
1010
PASS_THROUGH_BUILD_FAILURES,
1111
)
@@ -15,7 +15,7 @@
1515
class CompilationSettings:
1616
precision: LowerPrecision = PRECISION
1717
debug: bool = DEBUG
18-
workspace_size: int = MAX_WORKSPACE_SIZE
18+
workspace_size: int = WORKSPACE_SIZE
1919
min_block_size: int = MIN_BLOCK_SIZE
2020
torch_executed_ops: Sequence[str] = field(default_factory=set)
2121
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES

Diff for: py/torch_tensorrt/dynamo/backend/conversion.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch_tensorrt.fx.trt_module import TRTModule
44
from torch_tensorrt import TRTModuleNext
55
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
6-
from torch_tensorrt.fx.fx2trt import (
6+
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
77
InputTensorSpec,
88
TRTInterpreter,
99
)
@@ -24,15 +24,15 @@ def convert_module(
2424
Returns:
2525
TRTModule or TRTModuleNext
2626
"""
27-
interp = TRTInterpreter(
27+
interpreter = TRTInterpreter(
2828
module,
2929
InputTensorSpec.from_tensors(inputs),
3030
explicit_batch_dimension=True,
3131
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
3232
)
3333

34-
r = interp.run(
35-
max_workspace_size=settings.workspace_size,
34+
interpreter_result = interpreter.run(
35+
workspace_size=settings.workspace_size,
3636
lower_precision=settings.precision,
3737
profiling_verbosity=(
3838
trt.ProfilingVerbosity.VERBOSE
@@ -41,4 +41,8 @@ def convert_module(
4141
),
4242
)
4343

44-
return TRTModule(*r)
44+
return TRTModule(
45+
engine=interpreter_result.engine,
46+
input_names=interpreter_result.input_names,
47+
output_names=interpreter_result.output_names,
48+
)

0 commit comments

Comments
 (0)