|
2 | 2 |
|
3 | 3 | import logging
|
4 | 4 | from dataclasses import fields, replace
|
5 |
| -from typing import Any, Callable, Dict, Optional, Sequence |
| 5 | +from typing import Any, Callable, Dict, Optional, Sequence, Union |
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 | import torch_tensorrt
|
@@ -116,23 +116,45 @@ def prepare_inputs(
|
116 | 116 | )
|
117 | 117 |
|
118 | 118 |
|
119 |
| -def prepare_device(device: Device | torch.device) -> torch.device: |
120 |
| - _device: torch.device |
| 119 | +def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch.device: |
| 120 | + """Cast a device-type to torch.device |
| 121 | +
|
| 122 | + Returns the corresponding torch.device |
| 123 | + """ |
121 | 124 | if isinstance(device, Device):
|
122 | 125 | if device.gpu_id != -1:
|
123 |
| - _device = torch.device(device.gpu_id) |
| 126 | + return torch.device(device.gpu_id) |
124 | 127 | else:
|
125 | 128 | raise ValueError("Invalid GPU ID provided for the CUDA device provided")
|
126 | 129 |
|
127 | 130 | elif isinstance(device, torch.device):
|
128 |
| - _device = device |
| 131 | + return device |
| 132 | + |
| 133 | + elif device is None: |
| 134 | + return torch.device(torch.cuda.current_device()) |
129 | 135 |
|
130 | 136 | else:
|
131 |
| - raise ValueError( |
132 |
| - "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device" |
133 |
| - ) |
| 137 | + return torch.device(device) |
134 | 138 |
|
135 |
| - return _device |
| 139 | + |
| 140 | +def to_torch_tensorrt_device( |
| 141 | + device: Optional[Union[Device, torch.device, str]] |
| 142 | +) -> Device: |
| 143 | + """Cast a device-type to torch_tensorrt.Device |
| 144 | +
|
| 145 | + Returns the corresponding torch_tensorrt.Device |
| 146 | + """ |
| 147 | + if isinstance(device, Device): |
| 148 | + return device |
| 149 | + |
| 150 | + elif isinstance(device, torch.device): |
| 151 | + return Device(gpu_id=device.index) |
| 152 | + |
| 153 | + elif device is None: |
| 154 | + return Device(gpu_id=torch.cuda.current_device()) |
| 155 | + |
| 156 | + else: |
| 157 | + return Device(device) |
136 | 158 |
|
137 | 159 |
|
138 | 160 | def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
|
@@ -184,7 +206,17 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
|
184 | 206 | # Parse input runtime specification
|
185 | 207 | settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
|
186 | 208 |
|
187 |
| - logger.info("Compilation Settings: %s\n", settings) |
| 209 | + # Ensure device is a torch_tensorrt Device |
| 210 | + settings.device = to_torch_tensorrt_device(settings.device) |
| 211 | + |
| 212 | + # Check and update device settings |
| 213 | + if "device" not in kwargs: |
| 214 | + logger.info( |
| 215 | + f"Device not specified, using Torch default current device - cuda:{settings.device.gpu_id}. " |
| 216 | + "If this is incorrect, please specify an input device, via the device keyword." |
| 217 | + ) |
| 218 | + |
| 219 | + logger.info(f"Compiling with Settings:\n{settings}") |
188 | 220 |
|
189 | 221 | return settings
|
190 | 222 |
|
|
0 commit comments