diff --git a/examples/fx/fx2trt_example_next.py b/examples/fx/fx2trt_example_next.py
index f7b5ef1404..9fd1386df7 100644
--- a/examples/fx/fx2trt_example_next.py
+++ b/examples/fx/fx2trt_example_next.py
@@ -8,7 +8,10 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
 from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
-from torch_tensorrt import TRTModuleNext as TRTModule, Device
+from torch_tensorrt.dynamo._TorchTensorRTModule import (
+    TorchTensorRTModule as TRTModule,
+    Device,
+)
 
 # The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
 # model to TensorRT via FX with existing FX based tooling. The general lowering flow
diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py
index f7bc963343..952563f5ca 100644
--- a/py/torch_tensorrt/__init__.py
+++ b/py/torch_tensorrt/__init__.py
@@ -91,7 +91,6 @@ def _find_lib(name, paths):
 from torch_tensorrt import logging
 from torch_tensorrt._Input import Input
 from torch_tensorrt._Device import Device
-from torch_tensorrt._TRTModuleNext import TRTModuleNext
 
 from torch_tensorrt import fx
diff --git a/py/torch_tensorrt/_TRTModuleNext.py b/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
similarity index 89%
rename from py/torch_tensorrt/_TRTModuleNext.py
rename to py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
index ca77c5bd06..8359bc62fb 100644
--- a/py/torch_tensorrt/_TRTModuleNext.py
+++ b/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
@@ -1,6 +1,5 @@
 import logging
-from operator import truediv
-from typing import Any, List, Sequence, Tuple
+from typing import Any, List, Tuple
 
 import torch
 from torch_tensorrt import _C
@@ -9,8 +8,8 @@
 logger = logging.getLogger(__name__)
 
 
-class TRTModuleNext(torch.nn.Module):
-    """TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
+class TorchTensorRTModule(torch.nn.Module):
+    """TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
 
     This module is backed by the Torch-TensorRT runtime and is fully compatibile with both
     FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
@@ -20,7 +19,7 @@ class TRTModuleNext(torch.nn.Module):
     The forward function is simpily forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
     the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``
 
-    > Note: TRTModuleNext only supports engines built with explict batch
+    > Note: TorchTensorRTModule only supports engines built with explicit batch
 
     Attributes:
         name (str): Name of module (for easier debugging)
@@ -37,7 +36,7 @@ def __init__(
         output_binding_names: List[str] = [],
         target_device: Device = Device._current_device(),
     ):
-        """__init__ method for torch_tensorrt.TRTModuleNext
+        """__init__ method for torch_tensorrt.TorchTensorRTModule
 
         Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it.
@@ -70,10 +69,7 @@ def __init__(
             )
 
         """
-        logger.warning(
-            "TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
-        )
-        super(TRTModuleNext, self).__init__()
+        super(TorchTensorRTModule, self).__init__()
 
         if not isinstance(serialized_engine, bytearray):
             ValueError("Expected serialized engine as bytearray")
@@ -89,8 +85,8 @@ def __init__(
                 self.name + "_engine" if self.name != "" else "tensorrt_engine",
                 target_device._to_serialized_rt_device(),
                 serialized_engine,
-                TRTModuleNext._pack_binding_names(self.input_binding_names),
-                TRTModuleNext._pack_binding_names(self.output_binding_names),
+                TorchTensorRTModule._pack_binding_names(self.input_binding_names),
+                TorchTensorRTModule._pack_binding_names(self.output_binding_names),
             ]
         )
         else:
@@ -154,7 +150,7 @@ def is_non_tensor(i: Tuple[Any, bool]) -> bool:
                 non_tensors = [i[0] for i in filter(zip(inputs, types), is_non_tensor)]
                 raise RuntimeError(
-                    f"TRTModuleNext expects a flattened list of tensors as input, found non tensors: {non_tensors}"
+                    f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
                 )
 
         outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine)
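A quick orientation for reviewers: the renamed class is constructed exactly as before. A minimal sketch, assuming a TorchScript-compatible module; the module class and the `input_0`/`output_0` binding names are illustrative, mirroring the updated `tests/py/api/test_classes.py` later in this diff:

```python
import torch
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule


class MatMul(torch.nn.Module):  # illustrative module, not part of this PR
    def forward(self, x):
        return torch.matmul(x, torch.ones(10, 10).cuda())


mod = torch.jit.script(MatMul().eval().cuda())
engine_str = torchtrt.ts.convert_method_to_trt_engine(
    mod, "forward", inputs=[torchtrt.Input((2, 10))]
)

# forward() dispatches to torch.ops.tensorrt.execute_engine under the hood
trt_mod = TorchTensorRTModule(
    name="test",
    serialized_engine=engine_str,
    input_binding_names=["input_0"],
    output_binding_names=["output_0"],
)
out = trt_mod(torch.randn(2, 10).cuda())
```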
diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index 5afe888473..38e60fce41 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -4,7 +4,7 @@
 import torch_tensorrt
 from functools import partial
 
-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
 from torch_tensorrt.fx.utils import LowerPrecision
 
@@ -16,6 +16,10 @@
     WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_PYTHON_RUNTIME,
 )
 
 
@@ -45,6 +49,10 @@ def compile(
     torch_executed_ops=[],
     torch_executed_modules=[],
     pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams=MAX_AUX_STREAMS,
+    version_compatible=VERSION_COMPATIBLE,
+    optimization_level=OPTIMIZATION_LEVEL,
+    use_python_runtime=USE_PYTHON_RUNTIME,
     **kwargs,
 ):
     if debug:
@@ -91,6 +99,10 @@ def compile(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_python_runtime=use_python_runtime,
         **kwargs,
     )
 
@@ -114,6 +126,10 @@ def create_backend(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -125,6 +141,13 @@ def create_backend(
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
         pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible: Provide version forward-compatibility for engine plan files
+        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
     Returns:
         Backend for torch.compile
     """
@@ -136,4 +159,9 @@ def create_backend(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
+        use_python_runtime=use_python_runtime,
+        **kwargs,
     )
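The four new knobs flow through `compile` into `CompilationSettings` unchanged. A usage sketch mirroring `test_trt_specific_options` from the new test file later in this diff (`MyModel` is a hypothetical user model):

```python
import torch
from torch_tensorrt.dynamo import compile

model = MyModel().eval().cuda()  # hypothetical user model
inputs = [torch.randn(16, 3, 224, 224).cuda()]

trt_model = compile(
    model,
    inputs,
    min_block_size=1,
    max_aux_streams=5,         # cap auxiliary TRT streams per engine
    version_compatible=True,   # build forward-compatible engine plans
    optimization_level=4,      # builder effort 0-5; TRT's default is 3
    use_python_runtime=False,  # force the C++ runtime (TorchTensorRTModule)
)
```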
diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py
index bb34f2dcac..0afbc60f8c 100644
--- a/py/torch_tensorrt/dynamo/backend/_defaults.py
+++ b/py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -6,3 +6,7 @@
 WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
+MAX_AUX_STREAMS = None
+VERSION_COMPATIBLE = False
+OPTIMIZATION_LEVEL = None
+USE_PYTHON_RUNTIME = None
diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py
index 73bc08a419..d074a6b079 100644
--- a/py/torch_tensorrt/dynamo/backend/_settings.py
+++ b/py/torch_tensorrt/dynamo/backend/_settings.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Sequence
+from typing import Optional, Sequence
 
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.dynamo.backend._defaults import (
@@ -8,10 +8,14 @@
     WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_PYTHON_RUNTIME,
 )
 
 
-@dataclass(frozen=True)
+@dataclass
 class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
@@ -19,3 +23,7 @@ class CompilationSettings:
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS
+    version_compatible: bool = VERSION_COMPATIBLE
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 1d770b86a3..b97079948e 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -139,6 +139,7 @@ def _compile_module(
             submodule,
             submodule_inputs,
             settings=settings,
+            name=name,
         )
 
         trt_modules[name] = trt_mod
diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index 310b6f86ce..425fb0941e 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -1,7 +1,7 @@
 from typing import Sequence, Union
 import torch
+import io
 from torch_tensorrt.fx.trt_module import TRTModule
-from torch_tensorrt import TRTModuleNext
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
     InputTensorSpec,
@@ -15,12 +15,14 @@ def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
-) -> Union[TRTModuleNext, TRTModule]:
+    name: str = "",
+):
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
+        name: TRT engine name
     Returns:
         TRTModule or TRTModuleNext
     """
@@ -48,10 +50,27 @@ def convert_module(
             if settings.debug
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
         ),
+        max_aux_streams=settings.max_aux_streams,
+        version_compatible=settings.version_compatible,
+        optimization_level=settings.optimization_level,
     )
 
-    return TRTModule(
-        engine=interpreter_result.engine,
-        input_names=interpreter_result.input_names,
-        output_names=interpreter_result.output_names,
-    )
+    if settings.use_python_runtime:
+        return TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
+
+    else:
+        from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+
+        with io.BytesIO() as engine_bytes:
+            engine_bytes.write(interpreter_result.engine.serialize())
+            engine_str = engine_bytes.getvalue()
+
+        return TorchTensorRTModule(
+            serialized_engine=engine_str,
+            name=name,
+            input_binding_names=interpreter_result.input_names,
+            output_binding_names=interpreter_result.output_names,
+        )
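The `else` branch above hands the C++ runtime a serialized plan rather than a live engine object. A minimal sketch of that serialization idiom in isolation; the `serialize_engine` helper name is illustrative, and `engine` is assumed to be a built `tensorrt.ICudaEngine`:

```python
import io

import tensorrt as trt


def serialize_engine(engine: trt.ICudaEngine) -> bytes:
    # ICudaEngine.serialize() yields an IHostMemory buffer; round-tripping
    # through BytesIO produces the bytes object TorchTensorRTModule expects.
    with io.BytesIO() as engine_bytes:
        engine_bytes.write(engine.serialize())
        return engine_bytes.getvalue()
```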
diff --git a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
new file mode 100644
index 0000000000..2af251adbc
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
@@ -0,0 +1,173 @@
+from torch_tensorrt.dynamo.backend.lowering import partition
+from torch.testing._internal.common_utils import run_tests, TestCase
+import torch
+from copy import deepcopy
+from torch_tensorrt.dynamo import compile
+from utils import lower_graph_testing
+from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
+
+
+class TestTRTModuleNextCompilation(TestCase):
+    def test_trt_module_next_full_support(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                out = x - y
+                out = out + x
+                out = 2 * out
+                out = out + y
+                return torch.mean(out, dim=1)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+        inputs = [
+            torch.randint(-5, 5, (16, 7), dtype=torch.float).cuda(),
+            torch.randint(-5, 5, (16, 7), dtype=torch.float).cuda(),
+        ]
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            use_python_runtime=False,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+    def test_trt_module_next_partial_support(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                out = x - y
+                out = out - 3 * x
+                out = out + y
+                out = out.to(torch.float)
+                out = 2 * out
+                return torch.mean(out, dim=-1)
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.int).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.int).cuda(),
+        ]
+
+        (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+        self.assertEquals(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEquals(
+            len(list(partitioned_graphs[0].named_children())),
+            2,
+            "Certain operators are set to run in Torch, expected 2 segments",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            use_python_runtime=False,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+
+class TestCompilationOptions(TestCase):
+    def test_trt_specific_options(self):
+        class SupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                out = x - y
+                out = out - 3 * x
+                out = out + y
+                out = out - y / 5
+                out = 2 * out
+                return torch.mean(out, dim=-1)
+
+        fx_graph = torch.fx.symbolic_trace(SupportedMultiOp())
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.float).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.float).cuda(),
+        ]
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_python_runtime=False,
+            optimization_level=4,
+            version_compatible=True,
+            max_aux_streams=5,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py
index 9396373790..23a1cd4795 100644
--- a/py/torch_tensorrt/dynamo/backend/utils.py
+++ b/py/torch_tensorrt/dynamo/backend/utils.py
@@ -5,6 +5,7 @@
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device
+from ..common_utils import use_python_runtime_parser
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +103,9 @@ def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
     if settings.debug:
         logger.setLevel(logging.DEBUG)
 
+    # Parse input runtime specification
+    settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
+
     logger.debug(f"Compiling with Settings:\n{settings}")
 
     return settings
diff --git a/py/torch_tensorrt/dynamo/common_utils/__init__.py b/py/torch_tensorrt/dynamo/common_utils/__init__.py
index e69de29bb2..de0ce0a48a 100644
--- a/py/torch_tensorrt/dynamo/common_utils/__init__.py
+++ b/py/torch_tensorrt/dynamo/common_utils/__init__.py
@@ -0,0 +1,36 @@
+import logging
+from typing import Optional
+
+
+logger = logging.getLogger(__name__)
+
+
+def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
+    """Parses a user-provided input argument regarding Python runtime
+
+    Automatically handles cases where the user has not specified a runtime (None)
+
+    Returns True if the Python runtime should be used, False if the C++ runtime should be used
+    """
+    using_python_runtime = use_python_runtime
+    reason = ""
+
+    # Runtime was manually specified by the user
+    if using_python_runtime is not None:
+        reason = "as requested by user"
+    # Runtime was not manually specified by the user, automatically detect runtime
+    else:
+        try:
+            from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+
+            using_python_runtime = False
+            reason = "since C++ dependency was detected as present"
+        except ImportError:
+            using_python_runtime = True
+            reason = "since import failed, C++ dependency not installed"
+
+    logger.info(
+        f"Using {'Python' if using_python_runtime else 'C++'} TRT Runtime, {reason}"
+    )
+
+    return using_python_runtime
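The parser has three outcomes; a small usage sketch (the `None` branch is environment-dependent, so its result is not asserted):

```python
from torch_tensorrt.dynamo.common_utils import use_python_runtime_parser

assert use_python_runtime_parser(True) is True    # user explicitly chose Python
assert use_python_runtime_parser(False) is False  # user explicitly chose C++

# With None, the result depends on the install: False if the C++ dependency
# (torch_tensorrt.dynamo._TorchTensorRTModule) imports cleanly, True if that
# import raises ImportError.
auto_choice = use_python_runtime_parser(None)
```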
diff --git a/py/torch_tensorrt/dynamo/common_utils/test_utils.py b/py/torch_tensorrt/dynamo/common_utils/test_utils.py
index b258d122a3..873aed4c6b 100644
--- a/py/torch_tensorrt/dynamo/common_utils/test_utils.py
+++ b/py/torch_tensorrt/dynamo/common_utils/test_utils.py
@@ -1,7 +1,7 @@
 import torch
 
 COSINE_THRESHOLD = 0.99
-DECIMALS_OF_AGREEMENT = 5
+DECIMALS_OF_AGREEMENT = 4
 
 
 def cosine_similarity(gt_tensor, pred_tensor):
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
index 63477d894f..c0f1ae7870 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
@@ -14,6 +14,7 @@
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
 from .passes.pass_utils import PassFunc, validate_inference
+from ..common_utils import use_python_runtime_parser
 from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
 
@@ -48,7 +49,7 @@ def compile(
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
     is_aten=False,
-    use_experimental_fx_rt=False,
+    use_python_runtime=None,
     max_aux_streams=None,
     version_compatible=False,
     optimization_level=None,
@@ -70,7 +71,9 @@ def compile(
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
         cuda_graph_batch_size: Cuda graph batch size, default to be -1.
-        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
         max_aux_streams: max number of aux stream to use
         version_compatible: enable version compatible feature
         optimization_level: builder optimization level
@@ -111,6 +114,9 @@ def compile(
             "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
         )
 
+    # Parse user-specification of which runtime to use
+    use_python_runtime = use_python_runtime_parser(use_python_runtime)
+
     lower_setting = LowerSetting(
         device=device,
         min_block_size=min_block_size,
@@ -123,7 +129,7 @@ def compile(
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
         is_aten=is_aten,
-        use_experimental_rt=use_experimental_fx_rt,
+        use_python_runtime=use_python_runtime,
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
@@ -202,7 +208,7 @@ def default_split_function(
     splitter_setting = TRTSplitterSetting()
     splitter_setting.use_implicit_batch_dim = False
     splitter_setting.min_block_size = lower_setting.min_block_size
-    splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
+    splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime
     splitter = TRTSplitter(model, inputs, settings=splitter_setting)
     splitter.node_support_preview()
     return splitter.generate_split_results()
@@ -224,32 +230,30 @@ def lower_pass(
         """
         interpreter = create_trt_interpreter(lower_setting)
         interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
-        if lower_setting.use_experimental_rt:
-            import io
+        if lower_setting.use_python_runtime:
+            trt_module = TRTModule(
+                engine=interp_res.engine,
+                input_names=interp_res.input_names,
+                output_names=interp_res.output_names,
+                cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
+            )
+            return trt_module
+
+        else:
+            import io
             from torch_tensorrt._Device import Device
-            from torch_tensorrt._TRTModuleNext import TRTModuleNext
+            from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
 
             with io.BytesIO() as engine_bytes:
                 engine_bytes.write(interp_res.engine.serialize())
                 engine_str = engine_bytes.getvalue()
 
-            trt_module = TRTModuleNext(
+            trt_module = TorchTensorRTModule(
                 engine_str,
                 name=module_name,
                 input_binding_names=interp_res.input_names,
                 output_binding_names=interp_res.output_names,
                 target_device=Device(f"cuda:{torch.cuda.current_device()}"),
-                # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
-            )
-            return trt_module
-
-        else:
-            trt_module = TRTModule(
-                engine=interp_res.engine,
-                input_names=interp_res.input_names,
-                output_names=interp_res.output_names,
-                cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
             )
             return trt_module
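Note that the splitter still keys off `use_experimental_rt`, now derived as the inversion of the parsed `use_python_runtime` flag. A hedged sketch of the renamed fx_ts_compat entry point; the import path is assumed from the file location, and `None` defers the runtime choice to `use_python_runtime_parser`:

```python
import torch
import torchvision.models as models
from torch_tensorrt.dynamo.fx_ts_compat.lower import compile  # assumed import path

model = models.resnet18(pretrained=True).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# use_python_runtime=None auto-selects: the C++ runtime is preferred
# whenever its dependency is importable, otherwise Python is used.
trt_model = compile(model, inputs, use_python_runtime=None)
```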
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
index 64a67d1cc2..9301a2cd90 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
@@ -68,7 +68,8 @@ class LowerSetting(LowerSettingBasic):
         meaning all possible tactic sources.
         correctness_atol: absolute tolerance for correctness check
         correctness_rtol: relative tolerance for correctness check
-        use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        use_python_runtime: Whether to use Python runtime or C++ runtime. None implies the user has not
+            selected a runtime, and the frontend will automatically do so on their behalf
         max_aux_streams: max number of aux stream to use
         version_compatible: enable version compatible feature
         optimization_level: builder optimization level
@@ -95,7 +96,7 @@ class LowerSetting(LowerSettingBasic):
     tactic_sources: Optional[int] = None
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
-    use_experimental_rt: bool = False
+    use_python_runtime: Optional[bool] = None
     max_aux_streams: Optional[int] = None
     version_compatible: bool = False
     optimization_level: Optional[int] = None
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py
index f5c15b049b..bfb1964de9 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py
@@ -15,23 +15,30 @@
 def lower_mod_default(
     mod: torch.fx.GraphModule,
     inputs: Tensors,
-    use_experimental_rt: bool = False,
+    use_python_runtime: bool = False,
 ) -> TRTModule:
     interp = TRTInterpreter(
         mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
     )
     interpreter_result = interp.run()
-    if use_experimental_rt:
+    if use_python_runtime:
+        res_mod = TRTModule(
+            interpreter_result.engine,
+            interpreter_result.input_names,
+            interpreter_result.output_names,
+        )
+
+    else:
         import io
 
         from torch_tensorrt._Device import Device
-        from torch_tensorrt._TRTModuleNext import TRTModuleNext
+        from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
 
         with io.BytesIO() as engine_bytes:
             engine_bytes.write(interpreter_result.engine.serialize())
             engine_str = engine_bytes.getvalue()
 
-        res_mod = TRTModuleNext(
+        res_mod = TorchTensorRTModule(
             engine_str,
             name=str(type(mod)),
             input_binding_names=interpreter_result.input_names,
@@ -39,12 +46,7 @@ def lower_mod_default(
             output_binding_names=interpreter_result.output_names,
             target_device=Device(f"cuda:{torch.cuda.current_device()}"),
             # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
         )
-    else:
-        res_mod = TRTModule(
-            interpreter_result.engine,
-            interpreter_result.input_names,
-            interpreter_result.output_names,
-        )
+
     return res_mod
diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
index 462fe04e70..f34aad6caf 100644
--- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
+++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -31,6 +31,8 @@ def test_resnet18(ir):
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -62,6 +64,8 @@ def test_mobilenet_v2(ir):
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -93,6 +97,8 @@ def test_efficientnet_b0(ir):
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -133,6 +139,8 @@ def test_bert_base_uncased(ir):
         "truncate_long_and_double": True,
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -168,6 +176,8 @@ def test_resnet18_half(ir):
         "enabled_precisions": {torch.half},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
"pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 8, } trt_mod = torchtrt.compile(model, **compile_spec) diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 6572fe9588..5f66519e05 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -184,13 +184,13 @@ def lower_pass( import io from torch_tensorrt._Device import Device - from torch_tensorrt._TRTModuleNext import TRTModuleNext + from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule with io.BytesIO() as engine_bytes: engine_bytes.write(interp_res.engine.serialize()) engine_str = engine_bytes.getvalue() - trt_module = TRTModuleNext( + trt_module = TorchTensorRTModule( engine_str, name=module_name, input_binding_names=interp_res.input_names, diff --git a/py/torch_tensorrt/fx/test/core/test_trt_module.py b/py/torch_tensorrt/fx/test/core/test_trt_module.py index df4de754ba..2cb52fd130 100644 --- a/py/torch_tensorrt/fx/test/core/test_trt_module.py +++ b/py/torch_tensorrt/fx/test/core/test_trt_module.py @@ -10,7 +10,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule -# from torch_tensorrt import TRTModuleNext +# from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule # from torch_tensorrt import Device from torch_tensorrt.fx.utils import LowerPrecision @@ -59,7 +59,7 @@ def forward(self, x): # TODO add unittest.skip later -# class TestTRTModuleNext(TestCase): +# class TestTorchTensorRTModule(TestCase): # def test_save_and_load_trt_module(self): # class TestModule(torch.nn.Module): # def forward(self, x): @@ -82,7 +82,7 @@ def forward(self, x): # engine_bytes.write(interp_res.engine.serialize()) # engine_str = engine_bytes.getvalue() -# trt_mod = TRTModuleNext( +# trt_mod = TorchTensorRTModule( # name="TestModule", # serialized_engine=engine_str, # input_binding_names=interp_res.input_names, @@ -122,7 +122,7 @@ def forward(self, x): # engine_bytes.write(interp_res.engine.serialize()) # engine_str = engine_bytes.getvalue() -# trt_mod = TRTModuleNext( +# trt_mod = TorchTensorRTModule( # name="TestModule", # serialized_engine=engine_str, # input_binding_names=interp_res.input_names, @@ -132,7 +132,7 @@ def forward(self, x): # st = trt_mod.state_dict() -# new_trt_mod = TRTModuleNext() +# new_trt_mod = TorchTensorRTModule() # new_trt_mod.load_state_dict(st) # torch.testing.assert_allclose( diff --git a/py/torch_tensorrt/fx/tools/trt_minimizer.py b/py/torch_tensorrt/fx/tools/trt_minimizer.py index f44a5e1d25..1c14b289cf 100644 --- a/py/torch_tensorrt/fx/tools/trt_minimizer.py +++ b/py/torch_tensorrt/fx/tools/trt_minimizer.py @@ -24,13 +24,13 @@ def lower_mod_default( import io from torch_tensorrt._Device import Device - from torch_tensorrt._TRTModuleNext import TRTModuleNext + from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule with io.BytesIO() as engine_bytes: engine_bytes.write(interpreter_result.engine.serialize()) engine_str = engine_bytes.getvalue() - res_mod = TRTModuleNext( + res_mod = TorchTensorRTModule( engine_str, name=str(type(mod)), input_binding_names=interpreter_result.input_names, diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index aa3d930bfb..6fcb40c0d8 100644 --- a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -95,13 +95,13 @@ def _lower_model_to_backend( import io from 
diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py
index aa3d930bfb..6fcb40c0d8 100644
--- a/py/torch_tensorrt/fx/tools/trt_splitter.py
+++ b/py/torch_tensorrt/fx/tools/trt_splitter.py
@@ -95,13 +95,13 @@ def _lower_model_to_backend(
         import io
 
         from torch_tensorrt._Device import Device
-        from torch_tensorrt._TRTModuleNext import TRTModuleNext
+        from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
 
         with io.BytesIO() as engine_bytes:
             engine_bytes.write(interpreter_result.engine.serialize())
             engine_str = engine_bytes.getvalue()
 
-        return TRTModuleNext(
+        return TorchTensorRTModule(
             engine_str,
             name=str(type(mod)),
             input_binding_names=interpreter_result.input_names,
diff --git a/tests/py/api/test_classes.py b/tests/py/api/test_classes.py
index b9729b9d4d..3d0cb5c5f9 100644
--- a/tests/py/api/test_classes.py
+++ b/tests/py/api/test_classes.py
@@ -1,5 +1,6 @@
 import unittest
 import torch_tensorrt as torchtrt
+from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
 import torch
 import torchvision.models as models
 import copy
@@ -238,7 +239,7 @@ def test_dynamic_shape(self):
         self.assertTrue(self._verify_correctness(ts_i, target))
 
 
-class TestTRTModuleNext(unittest.TestCase):
+class TestTorchTensorRTModule(unittest.TestCase):
     @staticmethod
     def _get_trt_mod():
         class Test(torch.nn.Module):
@@ -255,7 +256,7 @@ def forward(self, x):
         test_mod_engine_str = torchtrt.ts.convert_method_to_trt_engine(
             mod, "forward", inputs=[torchtrt.Input((2, 10))]
         )
-        return torchtrt.TRTModuleNext(
+        return TorchTensorRTModule(
            name="test",
             serialized_engine=test_mod_engine_str,
             input_binding_names=["input_0"],
@@ -278,7 +279,7 @@ def forward(self, x):
             mod, "forward", inputs=[torchtrt.Input((2, 10))]
         )
         with self.assertRaises(RuntimeError):
-            torchtrt.TRTModuleNext(
+            TorchTensorRTModule(
                 name="test",
                 serialized_engine=test_mod_engine_str,
                 input_binding_names=["x.1"],
@@ -301,7 +302,7 @@ def forward(self, x):
             mod, "forward", inputs=[torchtrt.Input((2, 10))]
         )
         with self.assertRaises(RuntimeError):
-            torchtrt.TRTModuleNext(
+            TorchTensorRTModule(
                 name="test",
                 serialized_engine=test_mod_engine_str,
                 input_binding_names=["input_0"],
@@ -309,7 +310,7 @@ def forward(self, x):
             )
 
     def test_set_get_profile_path_prefix(self):
-        trt_mod = TestTRTModuleNext._get_trt_mod()
+        trt_mod = TestTorchTensorRTModule._get_trt_mod()
         trt_mod.engine.profile_path_prefix = "/tmp/"
         self.assertTrue(trt_mod.engine.profile_path_prefix == "/tmp/")
 
@@ -331,7 +332,7 @@ def test_get_layer_info(self):
 
         import json
 
-        trt_mod = TestTRTModuleNext._get_trt_mod()
+        trt_mod = TestTorchTensorRTModule._get_trt_mod()
         trt_json = json.loads(trt_mod.get_layer_info())
         [self.assertTrue(k in trt_json.keys()) for k in ["Layers", "Bindings"]]
         self.assertTrue(len(trt_json["Layers"]) == 4)
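Since `TorchTensorRTModule` carries its engine in the state dict, save/load round-trips are expected to work. A sketch based on the still-commented-out `test_save_and_load_trt_module` above, which this PR does not yet re-enable; the `Test` module and binding names are illustrative:

```python
import torch
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule


class Test(torch.nn.Module):  # illustrative, mirrors the helper in test_classes.py
    def forward(self, x):
        return x + x


mod = torch.jit.script(Test().eval().cuda())
engine_str = torchtrt.ts.convert_method_to_trt_engine(
    mod, "forward", inputs=[torchtrt.Input((2, 10))]
)
trt_mod = TorchTensorRTModule(
    name="test",
    serialized_engine=engine_str,
    input_binding_names=["input_0"],
    output_binding_names=["output_0"],
)

st = trt_mod.state_dict()  # the serialized engine travels inside the state dict

new_trt_mod = TorchTensorRTModule()
new_trt_mod.load_state_dict(st)  # rebuilds the engine on load

x = torch.randn(2, 10).cuda()
torch.testing.assert_allclose(trt_mod(x), new_trt_mod(x))
```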