Skip to content

Commit ba18185

Browse files
committed
feat: Add support for device compilation setting
- Add updated Device utilities and automatic context-aware device detection for torch compile - Add testing for new utilities
1 parent 56b8950 commit ba18185

File tree

8 files changed

+115
-29
lines changed

8 files changed

+115
-29
lines changed

py/torch_tensorrt/_Device.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
import warnings
1010

11-
import torch
12-
from torch_tensorrt import logging
13-
1411
# from torch_tensorrt import _enums
1512
import tensorrt as trt
13+
import torch
14+
from torch_tensorrt import logging
1615

1716
try:
1817
from torch_tensorrt import _C
@@ -120,6 +119,9 @@ def __str__(self) -> str:
120119
)
121120
)
122121

122+
def __repr__(self) -> str:
123+
return self.__str__()
124+
123125
def _to_internal(self) -> _C.Device:
124126
internal_dev = _C.Device()
125127
if self.device_type == trt.DeviceType.GPU:

py/torch_tensorrt/_compile.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,12 @@ def compile(
209209
import collections.abc
210210

211211
from torch_tensorrt import Device
212-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
212+
from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device
213213

214214
if not isinstance(inputs, collections.abc.Sequence):
215215
inputs = [inputs]
216216
device = kwargs.get("device", Device._current_device())
217-
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
217+
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
218218
module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
219219
compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
220220
module,

py/torch_tensorrt/dynamo/_defaults.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import torch
2+
from torch_tensorrt._Device import Device
23

34
PRECISION = torch.float32
45
DEBUG = False
6+
DEVICE = None
57
WORKSPACE_SIZE = 0
68
MIN_BLOCK_SIZE = 5
79
PASS_THROUGH_BUILD_FAILURES = False
@@ -12,3 +14,7 @@
1214
USE_PYTHON_RUNTIME = False
1315
USE_FAST_PARTITIONER = True
1416
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
17+
18+
19+
def default_device() -> Device:
20+
return Device(gpu_id=torch.cuda.current_device())

py/torch_tensorrt/dynamo/_settings.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Set
33

44
import torch
5+
from torch_tensorrt._Device import Device
56
from torch_tensorrt.dynamo._defaults import (
67
DEBUG,
78
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
@@ -15,6 +16,7 @@
1516
USE_PYTHON_RUNTIME,
1617
VERSION_COMPATIBLE,
1718
WORKSPACE_SIZE,
19+
default_device,
1820
)
1921

2022

@@ -54,3 +56,4 @@ class CompilationSettings:
5456
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
5557
use_fast_partitioner: bool = USE_FAST_PARTITIONER
5658
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
59+
device: Device = field(default_factory=default_device)

py/torch_tensorrt/dynamo/compile.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import collections.abc
44
import logging
5-
from typing import Any, List, Optional, Sequence, Set, Tuple
5+
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
66

77
import torch
88
import torch_tensorrt
@@ -13,6 +13,7 @@
1313
from torch_tensorrt.dynamo import CompilationSettings, partitioning
1414
from torch_tensorrt.dynamo._defaults import (
1515
DEBUG,
16+
DEVICE,
1617
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
1718
MAX_AUX_STREAMS,
1819
MIN_BLOCK_SIZE,
@@ -29,7 +30,11 @@
2930
convert_module,
3031
repair_long_or_double_inputs,
3132
)
32-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
33+
from torch_tensorrt.dynamo.utils import (
34+
prepare_inputs,
35+
to_torch_device,
36+
to_torch_tensorrt_device,
37+
)
3338

3439
logger = logging.getLogger(__name__)
3540

@@ -38,7 +43,7 @@ def compile(
3843
gm: Any,
3944
inputs: Any,
4045
*,
41-
device: Device = Device._current_device(),
46+
device: Optional[Union[Device, torch.device, str]] = DEVICE,
4247
disable_tf32: bool = False,
4348
sparse_weights: bool = False,
4449
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
@@ -82,7 +87,9 @@ def compile(
8287
if not isinstance(inputs, collections.abc.Sequence):
8388
inputs = [inputs]
8489

85-
_, torch_inputs = prepare_inputs(inputs, prepare_device(device))
90+
device = to_torch_tensorrt_device(device)
91+
92+
_, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
8693

8794
if (
8895
torch.float16 in enabled_precisions
@@ -105,6 +112,7 @@ def compile(
105112
compilation_options = {
106113
"precision": precision,
107114
"debug": debug,
115+
"device": device,
108116
"workspace_size": workspace_size,
109117
"min_block_size": min_block_size,
110118
"torch_executed_ops": torch_executed_ops

py/torch_tensorrt/dynamo/conversion/conversion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import io
44
from typing import Sequence
55

6+
import tensorrt as trt
67
import torch
78
from torch_tensorrt._Input import Input
89
from torch_tensorrt.dynamo import CompilationSettings
910
from torch_tensorrt.dynamo.conversion import TRTInterpreter
1011
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1112

12-
import tensorrt as trt
13-
1413

1514
def convert_module(
1615
module: torch.fx.GraphModule,
@@ -72,4 +71,5 @@ def convert_module(
7271
name=name,
7372
input_binding_names=list(interpreter_result.input_names),
7473
output_binding_names=list(interpreter_result.output_names),
74+
target_device=settings.device,
7575
)

py/torch_tensorrt/dynamo/utils.py

+42-10
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from dataclasses import fields, replace
5-
from typing import Any, Callable, Dict, Optional, Sequence
5+
from typing import Any, Callable, Dict, Optional, Sequence, Union
66

77
import torch
88
import torch_tensorrt
@@ -116,23 +116,45 @@ def prepare_inputs(
116116
)
117117

118118

119-
def prepare_device(device: Device | torch.device) -> torch.device:
120-
_device: torch.device
119+
def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch.device:
120+
"""Cast a device-type to torch.device
121+
122+
Returns the corresponding torch.device
123+
"""
121124
if isinstance(device, Device):
122125
if device.gpu_id != -1:
123-
_device = torch.device(device.gpu_id)
126+
return torch.device(device.gpu_id)
124127
else:
125128
raise ValueError("Invalid GPU ID provided for the CUDA device provided")
126129

127130
elif isinstance(device, torch.device):
128-
_device = device
131+
return device
132+
133+
elif device is None:
134+
return torch.device(torch.cuda.current_device())
129135

130136
else:
131-
raise ValueError(
132-
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
133-
)
137+
return torch.device(device)
134138

135-
return _device
139+
140+
def to_torch_tensorrt_device(
141+
device: Optional[Union[Device, torch.device, str]]
142+
) -> Device:
143+
"""Cast a device-type to torch_tensorrt.Device
144+
145+
Returns the corresponding torch_tensorrt.Device
146+
"""
147+
if isinstance(device, Device):
148+
return device
149+
150+
elif isinstance(device, torch.device):
151+
return Device(gpu_id=device.index)
152+
153+
elif device is None:
154+
return Device(gpu_id=torch.cuda.current_device())
155+
156+
else:
157+
return Device(device)
136158

137159

138160
def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
@@ -184,7 +206,17 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
184206
# Parse input runtime specification
185207
settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
186208

187-
logger.info("Compilation Settings: %s\n", settings)
209+
# Ensure device is a torch_tensorrt Device
210+
settings.device = to_torch_tensorrt_device(settings.device)
211+
212+
# Check and update device settings
213+
if "device" not in kwargs:
214+
logger.info(
215+
f"Device not specified, using Torch default current device - cuda:{settings.device.gpu_id}. "
216+
"If this is incorrect, please specify an input device, via the device keyword."
217+
)
218+
219+
logger.info(f"Compiling with Settings:\n{settings}")
188220

189221
return settings
190222

tests/py/dynamo/backend/test_compiler_utils.py

+43-8
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,61 @@
1-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
2-
from utils import same_output_format
3-
import torch_tensorrt
41
import unittest
2+
53
import torch
4+
import torch_tensorrt
5+
from torch_tensorrt.dynamo.utils import (
6+
prepare_inputs,
7+
to_torch_device,
8+
to_torch_tensorrt_device,
9+
)
10+
from utils import same_output_format
611

712

8-
class TestPrepareDevice(unittest.TestCase):
9-
def test_prepare_cuda_device(self):
13+
class TestToTorchDevice(unittest.TestCase):
14+
def test_cast_cuda_device(self):
1015
gpu_id = 0
1116
device = torch.device(f"cuda:{gpu_id}")
12-
prepared_device = prepare_device(device)
17+
prepared_device = to_torch_device(device)
1318
self.assertTrue(isinstance(prepared_device, torch.device))
1419
self.assertTrue(prepared_device.index == gpu_id)
1520

16-
def test_prepare_trt_device(self):
21+
def test_cast_trt_device(self):
1722
gpu_id = 4
1823
device = torch_tensorrt.Device(gpu_id=gpu_id)
19-
prepared_device = prepare_device(device)
24+
prepared_device = to_torch_device(device)
25+
self.assertTrue(isinstance(prepared_device, torch.device))
26+
self.assertTrue(prepared_device.index == gpu_id)
27+
28+
def test_cast_str_device(self):
29+
gpu_id = 2
30+
device = f"cuda:{gpu_id}"
31+
prepared_device = to_torch_device(device)
2032
self.assertTrue(isinstance(prepared_device, torch.device))
2133
self.assertTrue(prepared_device.index == gpu_id)
2234

2335

36+
class TestToTorchTRTDevice(unittest.TestCase):
37+
def test_cast_cuda_device(self):
38+
gpu_id = 0
39+
device = torch.device(f"cuda:{gpu_id}")
40+
prepared_device = to_torch_tensorrt_device(device)
41+
self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
42+
self.assertTrue(prepared_device.gpu_id == gpu_id)
43+
44+
def test_cast_trt_device(self):
45+
gpu_id = 4
46+
device = torch_tensorrt.Device(gpu_id=gpu_id)
47+
prepared_device = to_torch_tensorrt_device(device)
48+
self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
49+
self.assertTrue(prepared_device.gpu_id == gpu_id)
50+
51+
def test_cast_str_device(self):
52+
gpu_id = 2
53+
device = f"cuda:{gpu_id}"
54+
prepared_device = to_torch_tensorrt_device(device)
55+
self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
56+
self.assertTrue(prepared_device.gpu_id == gpu_id)
57+
58+
2459
class TestPrepareInputs(unittest.TestCase):
2560
def test_prepare_single_tensor_input(self):
2661
inputs = [torch.ones((4, 4))]

0 commit comments

Comments
 (0)