@@ -11,75 +11,50 @@ namespace core {
 namespace runtime {
 
 // Checks if the context switch is required for device ID
-bool is_switch_required(const CudaDevice& curr_device, const CudaDevice& conf_device) {
+bool is_switch_required(const CudaDevice& curr_device, const CudaDevice& engine_device) {
   // If SM capability is not the same as configured then switch
-  if ((curr_device.major != conf_device.major) || (curr_device.minor != conf_device.minor)) {
+  if ((curr_device.major != engine_device.major) || (curr_device.minor != engine_device.minor)) {
     LOG_WARNING(
-        "Configured SM capability " << conf_device.getSMCapability()
+        "Configured SM capability " << engine_device.getSMCapability()
         << " does not match with current device SM capability "
         << curr_device.getSMCapability() << " (" << curr_device
         << "). Switching device context");
     return true;
   }
 
   // GPU case
-  if (conf_device.device_type == nvinfer1::DeviceType::kGPU) {
-    if (curr_device.device_name != conf_device.device_name) {
+  if (engine_device.device_type == nvinfer1::DeviceType::kGPU) {
+    if (curr_device.device_name != engine_device.device_name) {
       LOG_WARNING(
-          "Program compiled for " << conf_device.device_name << " but current CUDA device is " << curr_device
+          "Program compiled for " << engine_device.device_name << " but current CUDA device is " << curr_device
           << ". Attempting to switch device context for better compatibility");
       return true;
     }
   }
 
-  if (curr_device.id != conf_device.id) {
+  if (curr_device.id != engine_device.id) {
     LOG_WARNING(
-        "Configured Device ID: " << conf_device.id << " is different that current device ID: " << curr_device.id
-        << ". Moving input tensors to device: " << conf_device.id);
+        "Configured Device ID: " << engine_device.id << " is different that current device ID: " << curr_device.id
+        << ". Moving input tensors to device: " << engine_device.id);
     return true;
   }
 
   return false;
 }
 
-CudaDevice select_cuda_device(const CudaDevice& conf_device) {
-  int64_t device_id = -1;
-  auto dla_supported = get_dla_supported_SMs();
-
-  auto device_list = get_available_device_list().get_devices();
-
-  CudaDevice new_target_device;
-
-  for (auto device : device_list) {
-    auto compute_cap = device.second.getSMCapability();
-    // In case of DLA select the DLA supported device ID
-    if (conf_device.device_type == nvinfer1::DeviceType::kDLA) {
-      if (dla_supported.find(compute_cap) != dla_supported.end() &&
-          dla_supported[compute_cap] == device.second.device_name) {
-        device_id = device.second.id;
-        new_target_device = CudaDevice(device_id, nvinfer1::DeviceType::kDLA);
-        break;
-      }
-    } else if (conf_device.device_type == nvinfer1::DeviceType::kGPU) {
-      auto conf_sm = conf_device.getSMCapability();
-      if (compute_cap == conf_sm && device.second.device_name == conf_device.device_name) {
-        device_id = device.second.id;
-        new_target_device = CudaDevice(device_id, nvinfer1::DeviceType::kGPU);
-        break;
-      }
-    } else {
-      TRTORCH_THROW_ERROR("Unknown target device type detected from the compiled program (runtime.select_cuda_device)");
-      break;
-    }
-  }
+CudaDevice select_cuda_device(const CudaDevice& engine_device) {
+  auto new_target_device_opt = get_most_compatible_device(engine_device);
 
   // REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD
+  // TODO: I think this logic could be way simpler at execution time since if the tensors aren't on the right
+  // device, it's not going to run. We should just set device to engine device and maybe reset and memcpy tensors
+  // back to original device if needed.
   TRTORCH_CHECK(
-      device_id >= 0,
+      new_target_device_opt,
       "No compatible device found on system to run program.\n Program targets "
-          << conf_device << "\n Available targets: \n"
+          << engine_device << "\n Available targets: \n"
          << get_available_device_list().dump_list() << "\n(runtime.select_cuda_device)");
-  return new_target_device;
+  return new_target_device_opt.value();
 }
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
@@ -96,7 +71,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
     std::string target_device = "cuda:" + std::to_string(device.id);
 
     for (auto& in : inputs) {
-      in = in.to(at::kCUDA);
+      in = in.to(torch::Device(target_device));
     }
   }
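
Note on the new select_cuda_device flow: the hand-rolled device loop is replaced by a single call to get_most_compatible_device, which evidently returns an optional that TRTORCH_CHECK validates before .value() unwraps it. Below is a minimal, self-contained sketch of that check-then-unwrap pattern; the Device struct and the find_compatible/select_device names are illustrative stand-ins, not the actual TRTorch API.

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for runtime::CudaDevice; the field names mirror the
// ones used in the diff (id, major, minor, device_name) but this is not the
// real TRTorch type.
struct Device {
  int64_t id;
  int major;
  int minor;
  std::string device_name;
};

// Sketch of an optional-returning lookup in the spirit of
// get_most_compatible_device: return the first visible device whose SM
// capability matches the engine's, or std::nullopt if none does.
std::optional<Device> find_compatible(const Device& engine, const std::vector<Device>& visible) {
  for (const auto& d : visible) {
    if (d.major == engine.major && d.minor == engine.minor) {
      return d;
    }
  }
  return std::nullopt;  // no match; the caller decides how to report failure
}

// Mirrors the new select_cuda_device control flow: check the optional, then unwrap.
Device select_device(const Device& engine, const std::vector<Device>& visible) {
  auto target = find_compatible(engine, visible);
  if (!target) {
    throw std::runtime_error("No compatible device found on system to run program");
  }
  return target.value();
}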
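
The second hunk changes the input copy from a bare at::kCUDA, which resolves to whichever CUDA device is current, to the specific "cuda:<id>" string built from the selected device. A small standalone sketch of that pattern against the public libtorch API follows; the move_inputs helper is illustrative and not part of this PR.

#include <torch/torch.h>

#include <cstdint>
#include <string>
#include <vector>

// Illustrative helper (not part of the PR): copy each input tensor onto the
// device the engine targets, e.g. "cuda:1" rather than whichever CUDA device
// happens to be current.
std::vector<at::Tensor> move_inputs(std::vector<at::Tensor> inputs, int64_t device_id) {
  std::string target_device = "cuda:" + std::to_string(device_id);
  for (auto& in : inputs) {
    // Tensor::to returns the same tensor if it already lives on the target device.
    in = in.to(torch::Device(target_device));
  }
  return inputs;
}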