Skip to content

Commit

Permalink
FIX: distributed knn double sqrt bug (#2733)
Browse files Browse the repository at this point in the history
* FIX: distributed knn double sqrt bug

* fixed

* Update cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp

* Update cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp
  • Loading branch information
ethanglaser authored Apr 30, 2024
1 parent 693898e commit 8e7fc7a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 22 deletions.
16 changes: 8 additions & 8 deletions cpp/oneapi/dal/algo/knn/backend/gpu/infer_kernel_impl_dpc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,12 @@ class knn_callback {

const auto& [first, last] = bnds;
ONEDAL_ASSERT(last > first);
auto& queue = this->queue_;

bk::event_vector ndeps{ deps.cbegin(), deps.cend() };
auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps);
if (this->compute_sqrt_)
ndeps.push_back(sq_event);
if (this->compute_sqrt_) {
auto sqrt_event = copy_with_sqrt(this->queue_, inp_dts, inp_dts, deps);
ndeps.push_back(sqrt_event);
}

auto out_rps = this->responses_.get_slice(first, last);
ONEDAL_ASSERT((last - first) == out_rps.get_count());
Expand All @@ -310,12 +310,12 @@ class knn_callback {

const auto& [first, last] = bnds;
ONEDAL_ASSERT(last > first);
auto& queue = this->queue_;

bk::event_vector ndeps{ deps.cbegin(), deps.cend() };
auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps);
if (this->compute_sqrt_)
ndeps.push_back(sq_event);
if (this->compute_sqrt_) {
auto sqrt_event = copy_with_sqrt(this->queue_, inp_dts, inp_dts, deps);
ndeps.push_back(sqrt_event);
}

auto out_rps = this->responses_.get_slice(first, last);
ONEDAL_ASSERT((last - first) == out_rps.get_count());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,16 +347,10 @@ class knn_callback_distr {

const auto& [first, last] = bnds;
ONEDAL_ASSERT(last > first);
auto& queue = this->queue_;

bk::event_vector ndeps{ deps.cbegin(), deps.cend() };
auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps);
if (this->compute_sqrt_)
ndeps.push_back(sq_event);

auto out_rps = this->responses_.get_slice(first, last);
ONEDAL_ASSERT((last - first) == out_rps.get_count());
return (*(this->distance_voting_))(tmp_rps, inp_dts, out_rps, ndeps);
return (*(this->distance_voting_))(tmp_rps, inp_dts, out_rps, deps);
}

template <typename T = Task, typename = detail::enable_if_regression_t<T>>
Expand All @@ -371,16 +365,10 @@ class knn_callback_distr {

const auto& [first, last] = bnds;
ONEDAL_ASSERT(last > first);
auto& queue = this->queue_;

bk::event_vector ndeps{ deps.cbegin(), deps.cend() };
auto sq_event = copy_with_sqrt(queue, inp_dts, inp_dts, deps);
if (this->compute_sqrt_)
ndeps.push_back(sq_event);

auto out_rps = this->responses_.get_slice(first, last);
ONEDAL_ASSERT((last - first) == out_rps.get_count());
return (*(this->distance_regression_))(tmp_rps, inp_dts, out_rps, ndeps);
return (*(this->distance_regression_))(tmp_rps, inp_dts, out_rps, deps);
}

sycl::event output_responses(const std::pair<idx_t, idx_t>& bnds,
Expand Down

0 comments on commit 8e7fc7a

Please # to comment.