Skip to content

Commit

Permalink
Fix all tests and examples to use new half traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed Apr 9, 2024
1 parent 646504a commit 0579555
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
8 changes: 5 additions & 3 deletions src/ArborX_DBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,16 @@ struct WithinRadiusGetter
{
float _r;

template <typename Point>
KOKKOS_FUNCTION auto operator()(Point const &point) const
template <typename Point, typename Index>
KOKKOS_FUNCTION auto
operator()(PairValueIndex<Point, Index> const &value) const
{
static_assert(GeometryTraits::is_point_v<Point>);

constexpr int dim = GeometryTraits::dimension_v<Point>;
auto const &hyper_point =
reinterpret_cast<ExperimentalHyperGeometry::Point<dim> const &>(point);
reinterpret_cast<ExperimentalHyperGeometry::Point<dim> const &>(
value.value);
using ArborX::intersects;
return intersects(ExperimentalHyperGeometry::Sphere<dim>{hyper_point, _r});
}
Expand Down
9 changes: 8 additions & 1 deletion src/details/ArborX_DetailsFDBSCAN.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <ArborX_Callbacks.hpp>
#include <ArborX_DetailsUnionFind.hpp>
#include <ArborX_PairValueIndex.hpp>
#include <ArborX_Predicates.hpp>

#include <Kokkos_Core.hpp>
Expand Down Expand Up @@ -49,8 +50,14 @@ struct FDBSCANCallback
UnionFind _union_find;
CorePointsType _is_core_point;

KOKKOS_FUNCTION auto operator()(int i, int j) const
template <typename Value, typename Index>
KOKKOS_FUNCTION auto
operator()(PairValueIndex<Value, Index> const &value1,
PairValueIndex<Value, Index> const &value2) const
{
int i = value1.index;
int j = value2.index;

bool const is_border_point = !_is_core_point(i);
bool const neighbor_is_core_point = _is_core_point(j);
if (is_border_point)
Expand Down
25 changes: 14 additions & 11 deletions src/details/ArborX_NeighborList.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ struct NeighborListPredicateGetter
{
float _radius;

template <typename Point>
KOKKOS_FUNCTION auto operator()(Point point) const
template <typename Point, typename Index>
KOKKOS_FUNCTION auto
operator()(PairValueIndex<Point, Index> const &pair) const
{
static_assert(GeometryTraits::is_point_v<Point>);

constexpr int dim = GeometryTraits::dimension_v<Point>;
using Coordinate = typename GeometryTraits::coordinate_type_t<Point>;

auto const &hyper_point = reinterpret_cast<
ExperimentalHyperGeometry::Point<dim, Coordinate> const &>(point);
ExperimentalHyperGeometry::Point<dim, Coordinate> const &>(pair.value);
return intersects(ExperimentalHyperGeometry::Sphere<dim, Coordinate>{
hyper_point, _radius});
}
Expand Down Expand Up @@ -76,7 +77,9 @@ void findHalfNeighborList(ExecutionSpace const &space,
Kokkos::deep_copy(space, offsets, 0);
HalfTraversal(
space, bvh,
KOKKOS_LAMBDA(int, int j) { Kokkos::atomic_increment(&offsets(j)); },
KOKKOS_LAMBDA(auto, auto const &value) {
Kokkos::atomic_increment(&offsets(value.index));
},
NeighborListPredicateGetter{radius});
KokkosExt::exclusive_scan(space, offsets, offsets, 0);
KokkosExt::reallocWithoutInitializing(space, indices,
Expand All @@ -90,8 +93,8 @@ void findHalfNeighborList(ExecutionSpace const &space,
"ArborX::Experimental::HalfNeighborList::counts");
HalfTraversal(
space, bvh,
KOKKOS_LAMBDA(int i, int j) {
indices(Kokkos::atomic_fetch_inc(&counts(j))) = i;
KOKKOS_LAMBDA(auto const &value1, auto const &value2) {
indices(Kokkos::atomic_fetch_inc(&counts(value2.index))) = value1.index;
},
NeighborListPredicateGetter{radius});

Expand Down Expand Up @@ -132,9 +135,9 @@ void findFullNeighborList(ExecutionSpace const &space,
Kokkos::deep_copy(space, offsets, 0);
HalfTraversal(
space, bvh,
KOKKOS_LAMBDA(int i, int j) {
Kokkos::atomic_increment(&offsets(i));
Kokkos::atomic_increment(&offsets(j));
KOKKOS_LAMBDA(auto const &value1, auto const &value2) {
Kokkos::atomic_increment(&offsets(value1.index));
Kokkos::atomic_increment(&offsets(value2.index));
},
NeighborListPredicateGetter{radius});
KokkosExt::exclusive_scan(space, offsets, offsets, 0);
Expand All @@ -149,8 +152,8 @@ void findFullNeighborList(ExecutionSpace const &space,
"ArborX::Experimental::FullNeighborList::counts");
HalfTraversal(
space, bvh,
KOKKOS_LAMBDA(int i, int j) {
indices(Kokkos::atomic_fetch_inc(&counts(j))) = i;
KOKKOS_LAMBDA(auto const &value1, auto const &value2) {
indices(Kokkos::atomic_fetch_inc(&counts(value2.index))) = value1.index;
},
NeighborListPredicateGetter{radius});

Expand Down
4 changes: 3 additions & 1 deletion test/tstDetailsHalfTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(half_traversal, DeviceType, ARBORX_DEVICE_TYPES)
using ArborX::Details::HalfTraversal;
HalfTraversal(
exec_space, bvh,
KOKKOS_LAMBDA(int i, int j) {
KOKKOS_LAMBDA(auto const &value1, auto const &value2) {
int i = value1.index;
int j = value2.index;
auto [min_ij, max_ij] = Kokkos::minmax(i, j);
Kokkos::atomic_increment(&count(max_ij * (max_ij + 1) / 2 + min_ij));
},
Expand Down

0 comments on commit 0579555

Please # to comment.