From 8e7fc7ac6a48f80b2fa31100102e1a4b81db6848 Mon Sep 17 00:00:00 2001 From: ethanglaser <42726565+ethanglaser@users.noreply.github.com> Date: Tue, 30 Apr 2024 06:23:17 -0700 Subject: [PATCH] FIX: distributed knn double sqrt bug (#2733) * FIX: distributed knn double sqrt bug * fixed * Update cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp * Update cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp --- .../knn/backend/gpu/infer_kernel_impl_dpc.hpp | 16 ++++++++-------- .../backend/gpu/infer_kernel_impl_dpc_distr.hpp | 16 ++-------------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp b/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp index 9a31ef369ae..1ba10bf8737 100644 --- a/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp +++ b/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp @@ -286,12 +286,12 @@ class knn_callback { const auto& [first, last] = bnds; ONEDAL_ASSERT(last > first); - auto& queue = this->queue_; bk::event_vector ndeps{ deps.cbegin(), deps.cend() }; - auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps); - if (this->compute_sqrt_) - ndeps.push_back(sq_event); + if (this->compute_sqrt_) { + auto sqrt_event = copy_with_sqrt(this->queue_, inp_dts, inp_dts, deps); + ndeps.push_back(sqrt_event); + } auto out_rps = this->responses_.get_slice(first, last); ONEDAL_ASSERT((last - first) == out_rps.get_count()); @@ -310,12 +310,12 @@ class knn_callback { const auto& [first, last] = bnds; ONEDAL_ASSERT(last > first); - auto& queue = this->queue_; bk::event_vector ndeps{ deps.cbegin(), deps.cend() }; - auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps); - if (this->compute_sqrt_) - ndeps.push_back(sq_event); + if (this->compute_sqrt_) { + auto sqrt_event = copy_with_sqrt(this->queue_, inp_dts, inp_dts, deps); + ndeps.push_back(sqrt_event); + } auto out_rps = this->responses_.get_slice(first, last); ONEDAL_ASSERT((last - first) == out_rps.get_count()); diff --git a/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc_distr.hpp b/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc_distr.hpp index e67d555616a..daf3caa9187 100644 --- a/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc_distr.hpp +++ b/cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc_distr.hpp @@ -347,16 +347,10 @@ class knn_callback_distr { const auto& [first, last] = bnds; ONEDAL_ASSERT(last > first); - auto& queue = this->queue_; - - bk::event_vector ndeps{ deps.cbegin(), deps.cend() }; - auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps); - if (this->compute_sqrt_) - ndeps.push_back(sq_event); auto out_rps = this->responses_.get_slice(first, last); ONEDAL_ASSERT((last - first) == out_rps.get_count()); - return (*(this->distance_voting_))(tmp_rps, inp_dts, out_rps, ndeps); + return (*(this->distance_voting_))(tmp_rps, inp_dts, out_rps, deps); } template > @@ -371,16 +365,10 @@ class knn_callback_distr { const auto& [first, last] = bnds; ONEDAL_ASSERT(last > first); - auto& queue = this->queue_; - - bk::event_vector ndeps{ deps.cbegin(), deps.cend() }; - auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps); - if (this->compute_sqrt_) - ndeps.push_back(sq_event); auto out_rps = this->responses_.get_slice(first, last); ONEDAL_ASSERT((last - first) == out_rps.get_count()); - return (*(this->distance_regression_))(tmp_rps, inp_dts, out_rps, ndeps); + return (*(this->distance_regression_))(tmp_rps, inp_dts, out_rps, deps); } sycl::event output_responses(const std::pair& bnds,