Skip to content

Commit f1d0a43

Browse files
committed
fix(device_conf): Devices never actually got switched in multi device
cases Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 744b417 commit f1d0a43

File tree

8 files changed

+103
-48
lines changed

8 files changed

+103
-48
lines changed

core/runtime/CudaDevice.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ CudaDevice::CudaDevice(std::string device_info) {
6666
LOG_DEBUG("Deserialized Device Info: " << *this);
6767
}
6868

69+
// Member-wise copy assignment over the serialized device fields.
// Safe under self-assignment (pure member copies, no owned resources).
// NOTE(review): this appears equivalent to the compiler-generated assignment —
// confirm CudaDevice has no other data members before replacing with `= default`.
CudaDevice& CudaDevice::operator=(const CudaDevice& other) {
  device_name = other.device_name;
  device_type = other.device_type;
  id = other.id;
  major = other.major;
  minor = other.minor;
  return *this;
}
77+
6978
std::string CudaDevice::serialize() {
7079
std::vector<std::string> content;
7180
content.resize(DEVICE_NAME_IDX + 1);

core/runtime/TRTEngine.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) {
3838
}
3939

4040
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device) {
41-
device_info = cuda_device;
41+
auto most_compatible_device = get_most_compatible_device(cuda_device);
42+
TRTORCH_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
43+
device_info = most_compatible_device.value();
4244
set_cuda_device(device_info);
4345

4446
rt = std::shared_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(util::logging::get_logger()));

core/runtime/register_trt_op.cpp

+18-43
Original file line numberDiff line numberDiff line change
@@ -11,75 +11,50 @@ namespace core {
1111
namespace runtime {
1212

1313
// Checks if the context switch requred for device ID
14-
bool is_switch_required(const CudaDevice& curr_device, const CudaDevice& conf_device) {
14+
bool is_switch_required(const CudaDevice& curr_device, const CudaDevice& engine_device) {
1515
// If SM capability is not the same as configured then switch
16-
if ((curr_device.major != conf_device.major) || (curr_device.minor != conf_device.minor)) {
16+
if ((curr_device.major != engine_device.major) || (curr_device.minor != engine_device.minor)) {
1717
LOG_WARNING(
18-
"Configured SM capability " << conf_device.getSMCapability()
18+
"Configured SM capability " << engine_device.getSMCapability()
1919
<< " does not match with current device SM capability "
2020
<< curr_device.getSMCapability() << " (" << curr_device
2121
<< "). Switching device context");
2222
return true;
2323
}
2424

2525
// GPU case
26-
if (conf_device.device_type == nvinfer1::DeviceType::kGPU) {
27-
if (curr_device.device_name != conf_device.device_name) {
26+
if (engine_device.device_type == nvinfer1::DeviceType::kGPU) {
27+
if (curr_device.device_name != engine_device.device_name) {
2828
LOG_WARNING(
29-
"Program compiled for " << conf_device.device_name << " but current CUDA device is " << curr_device
29+
"Program compiled for " << engine_device.device_name << " but current CUDA device is " << curr_device
3030
<< ". Attempting to switch device context for better compatibility");
3131
return true;
3232
}
3333
}
3434

35-
if (curr_device.id != conf_device.id) {
35+
if (curr_device.id != engine_device.id) {
3636
LOG_WARNING(
37-
"Configured Device ID: " << conf_device.id << " is different that current device ID: " << curr_device.id
38-
<< ". Moving input tensors to device: " << conf_device.id);
37+
"Configured Device ID: " << engine_device.id << " is different that current device ID: " << curr_device.id
38+
<< ". Moving input tensors to device: " << engine_device.id);
3939
return true;
4040
}
4141

4242
return false;
4343
}
4444

45-
// Resolves the concrete device to run on for a program compiled against
// engine_device, aborting with a diagnostic listing available targets when
// no compatible device exists on the system.
CudaDevice select_cuda_device(const CudaDevice& engine_device) {
  // Delegate matching to the shared helper so selection logic lives in one place.
  auto candidate = get_most_compatible_device(engine_device);

  // REVIEW: this does not list DLA properly, which we should
  // TODO: this logic could be much simpler at execution time: if the tensors are not
  // on the right device the engine will not run anyway, so we could just set the
  // device to the engine device and, if needed, copy tensors back to their original
  // device afterwards.
  TRTORCH_CHECK(
      candidate,
      "No compatible device found on system to run program.\n Program targets "
          << engine_device << "\n Available targets: \n"
          << get_available_device_list().dump_list() << "\n(runtime.select_cuda_device)");
  return candidate.value();
}
8459

8560
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
@@ -96,7 +71,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
9671
std::string target_device = "cuda:" + std::to_string(device.id);
9772

9873
for (auto& in : inputs) {
99-
in = in.to(at::kCUDA);
74+
in = in.to(torch::Device(target_device));
10075
}
10176
}
10277

core/runtime/runtime.cpp

+63
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,72 @@ namespace trtorch {
77
namespace core {
88
namespace runtime {
99

10+
// Picks the best runnable device for target_device: among the compatible
// candidates, prefer an exact device-name match, and among name matches prefer
// the same device ID. Returns an empty optional when nothing compatible exists.
c10::optional<CudaDevice> get_most_compatible_device(const CudaDevice& target_device) {
  LOG_DEBUG("Target Device: " << target_device);
  auto candidates = find_compatible_devices(target_device);
  if (candidates.size() == 0) {
    return {};
  }
  if (candidates.size() == 1) {
    // Single candidate: no ranking needed.
    return {candidates[0]};
  }

  CudaDevice best;
  std::stringstream listing;
  listing << "[" << std::endl;
  for (const auto& candidate : candidates) {
    listing << " " << candidate << ',' << std::endl;
    if (candidate.device_name == target_device.device_name) {
      if (best.device_name != target_device.device_name) {
        // First name match beats any non-matching placeholder.
        best = candidate;
      } else if (candidate.id == target_device.id && best.id != target_device.id) {
        // Among name matches, prefer the one with the requested ID.
        best = candidate;
      }
    }
  }
  listing << ']';
  LOG_DEBUG("Compatible device options: " << listing.str());

  // NOTE(review): relies on a default-constructed CudaDevice having id == -1 —
  // confirm against the CudaDevice default constructor.
  if (best.id == -1) {
    LOG_DEBUG("No valid device options");
    return {};
  } else {
    LOG_DEBUG("Selected: " << best);
    return {best};
  }
}
43+
44+
// Scans the system device list and returns every device the target program
// could run on: for DLA targets, devices whose compute capability maps to the
// target's DLA-supported name; for GPU targets, devices with a matching SM
// capability. Throws on any other (unknown) target device type.
std::vector<CudaDevice> find_compatible_devices(const CudaDevice& target_device) {
  auto dla_supported = get_dla_supported_SMs();
  auto system_devices = get_available_device_list().get_devices();

  std::vector<CudaDevice> matches;

  for (const auto& entry : system_devices) {
    auto candidate_cc = entry.second.getSMCapability();
    if (target_device.device_type == nvinfer1::DeviceType::kDLA) {
      // DLA: the candidate's compute capability must map to the target's device name.
      auto it = dla_supported.find(candidate_cc);
      if (it != dla_supported.end() && dla_supported[candidate_cc] == target_device.device_name) {
        matches.push_back(entry.second);
      }
    } else if (target_device.device_type == nvinfer1::DeviceType::kGPU) {
      // GPU: matching SM capability should be good enough to run the engine.
      if (candidate_cc == target_device.getSMCapability()) {
        matches.push_back(entry.second);
      }
    } else {
      TRTORCH_THROW_ERROR(
          "Unknown target device type detected from the compiled program (runtime.find_compatible_devices)");
      break;
    }
  }
  return matches;
}
71+
1072
// Makes cuda_device the active CUDA device for this thread; aborts with a
// diagnostic if cudaSetDevice fails (e.g. invalid device ID).
void set_cuda_device(CudaDevice& cuda_device) {
  TRTORCH_CHECK(
      // Fix: the failure message was missing a space ("device: <dev>as active device").
      (cudaSetDevice(cuda_device.id) == cudaSuccess), "Unable to set device: " << cuda_device << " as active device");
  LOG_DEBUG("Setting " << cuda_device << " as active device");
}
1477

1578
CudaDevice get_current_device() {

core/runtime/runtime.h

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct CudaDevice {
2424
CudaDevice();
2525
CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type);
2626
CudaDevice(std::string serialized_device_info);
27+
CudaDevice& operator=(const CudaDevice& other);
2728
std::string serialize();
2829
std::string getSMCapability() const;
2930
friend std::ostream& operator<<(std::ostream& os, const CudaDevice& device);
@@ -33,6 +34,9 @@ void set_cuda_device(CudaDevice& cuda_device);
3334
// Gets the current active GPU (DLA will not show up through this)
3435
CudaDevice get_current_device();
3536

37+
c10::optional<CudaDevice> get_most_compatible_device(const CudaDevice& target_device);
38+
std::vector<CudaDevice> find_compatible_devices(const CudaDevice& target_device);
39+
3640
std::string serialize_device(CudaDevice& cuda_device);
3741
CudaDevice deserialize_device(std::string device_info);
3842

tests/core/conversion/evaluators/test_aten_evaluators.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "gtest/gtest.h"
44
#include "tests/util/util.h"
55
#include "torch/csrc/jit/ir/irparser.h"
6+
#include "torch/torch.h"
67

78
TEST(Evaluators, DivIntEvaluatesCorrectly) {
89
const auto graph = R"IR(

tests/cpp/BUILD

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ cc_test(
9494
name = "test_multi_gpu_serdes",
9595
srcs = ["test_multi_gpu_serdes.cpp"],
9696
data = [
97-
":jit_models",
97+
"//tests/modules:jit_models",
9898
],
9999
deps = [
100-
":module_test",
100+
":cpp_api_test",
101101
],
102102
)
103103

tests/cpp/test_multi_gpu_serdes.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ TEST_P(CppAPITests, CompiledModuleIsClose) {
2323
trt_results.push_back(trt_results_ivalues.toTensor());
2424

2525
for (size_t i = 0; i < trt_results.size(); i++) {
26-
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_results[i].reshape_as(jit_results[i]), 2e-5));
26+
ASSERT_TRUE(trtorch::tests::util::almostEqual(
27+
jit_results[i], trt_results[i].reshape_as(jit_results[i]).to(torch::Device("cuda:0")), 2e-5));
2728
}
2829
}
2930

3031
INSTANTIATE_TEST_SUITE_P(
3132
CompiledModuleForwardIsCloseSuite,
3233
CppAPITests,
33-
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}})));
34+
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));

0 commit comments

Comments
 (0)