From 703130af09e1ae52cf652d8aaaf3bbec8fb224b7 Mon Sep 17 00:00:00 2001 From: Jack Atkinson Date: Fri, 11 Oct 2024 08:11:15 +0100 Subject: [PATCH] Update C++ XPU interface to handle multiple devices indices. --- src/ctorch.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index 4d5f9f29..ab5e781b 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -63,11 +63,19 @@ const auto get_device(torch_device_t device_type, int device_index) } return torch::Device(torch::kMPS); case torch_kXPU: - if (device_index != -1) { - std::cerr << "[WARNING]: device index unused for XPU runs" + if (device_index == -1) { + std::cerr << "[WARNING]: device index unset, defaulting to 0" << std::endl; + device_index = 0; + } + if (device_index >= 0 && device_index < torch::xpu::device_count()) { + return torch::Device(torch::kXPU, device_index); + } else { + std::cerr << "[ERROR]: invalid device index " << device_index + << " for XPU device count " << torch::xpu::device_count() + << std::endl; + exit(EXIT_FAILURE); } - return torch::Device(torch::kXPU); default: std::cerr << "[WARNING]: unknown device type, setting to torch_kCPU" << std::endl;