Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Separate HappyTreeFunctions for internal and leaf nodes #864

Merged
merged 5 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/details/ArborX_DetailsDistributedTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ struct CallbackWithDistance
// need to be fixed with a proper callback abstraction.
int const leaf_node_index = _rev_permute(index);
auto const &leaf_node_bounding_volume =
HappyTreeFriends::getBoundingVolume(_tree, leaf_node_index);
HappyTreeFriends::getLeafBoundingVolume(_tree, leaf_node_index);
out({index, distance(getGeometry(query), leaf_node_bounding_volume)});
}
};
Expand Down
19 changes: 12 additions & 7 deletions src/details/ArborX_DetailsHalfTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,30 @@ struct HalfTraversal
KOKKOS_FUNCTION void operator()(int i) const
{
auto const predicate =
_get_predicate(HappyTreeFriends::getBoundingVolume(_bvh, i));
_get_predicate(HappyTreeFriends::getLeafBoundingVolume(_bvh, i));
auto const leaf_permutation_i =
HappyTreeFriends::getLeafPermutationIndex(_bvh, i);

int node = HappyTreeFriends::getRope(_bvh, i);
while (node != ROPE_SENTINEL)
{
if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node)))
bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node);

if (predicate(
(is_leaf
? HappyTreeFriends::getLeafBoundingVolume(_bvh, node)
: HappyTreeFriends::getInternalBoundingVolume(_bvh, node))))
{
if (!HappyTreeFriends::isLeaf(_bvh, node))
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
else
if (is_leaf)
{
_callback(leaf_permutation_i,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node));
node = HappyTreeFriends::getRope(_bvh, node);
}
else
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
}
else
{
Expand Down
39 changes: 26 additions & 13 deletions src/details/ArborX_DetailsHappyTreeFriends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
#include <type_traits>
#include <utility> // declval

namespace ArborX
{
namespace Details
namespace ArborX::Details
{

struct HappyTreeFriends
{
template <class BVH>
Expand All @@ -47,16 +46,32 @@ struct HappyTreeFriends
}

template <class BVH>
static KOKKOS_FUNCTION
// FIXME_HIP See https://github.com/arborx/ArborX/issues/553
#ifdef __HIP_DEVICE_COMPILE__
auto
#else
auto const &
#endif
getInternalBoundingVolume(BVH const &bvh, int i)
{
return bvh._internal_nodes(internalIndex(bvh, i)).bounding_volume;
}

template <class BVH>
static KOKKOS_FUNCTION
// FIXME_HIP See https://github.com/arborx/ArborX/issues/553
#ifdef __HIP_DEVICE_COMPILE__
static KOKKOS_FUNCTION auto getBoundingVolume(BVH const &bvh, int i)
auto
#else
static KOKKOS_FUNCTION auto const &getBoundingVolume(BVH const &bvh, int i)
auto const &
#endif
getLeafBoundingVolume(BVH const &bvh, int i)
{
auto const internal_i = internalIndex(bvh, i);
return (internal_i >= 0 ? bvh._internal_nodes(internal_i).bounding_volume
: bvh._leaf_nodes(i).bounding_volume);
static_assert(
std::is_same_v<decltype(bvh._internal_nodes(0).bounding_volume),
decltype(bvh._leaf_nodes(0).bounding_volume)>);
return bvh._leaf_nodes(i).bounding_volume;
}

template <class BVH>
Expand All @@ -83,12 +98,10 @@ struct HappyTreeFriends
template <class BVH>
static KOKKOS_FUNCTION auto getRope(BVH const &bvh, int i)
{
auto const internal_i = internalIndex(bvh, i);
return (internal_i >= 0 ? bvh._internal_nodes(internal_i).rope
: bvh._leaf_nodes(i).rope);
return (isLeaf(bvh, i) ? bvh._leaf_nodes(i).rope
: bvh._internal_nodes(internalIndex(bvh, i)).rope);
}
};
} // namespace Details
} // namespace ArborX
} // namespace ArborX::Details

#endif
56 changes: 36 additions & 20 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -78,7 +78,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
auto const &predicate = Access::get(_predicates, queryIndex);
auto const root = 0;
auto const &root_bounding_volume =
HappyTreeFriends::getBoundingVolume(_bvh, root);
HappyTreeFriends::getLeafBoundingVolume(_bvh, root);
if (predicate(root_bounding_volume))
{
_callback(predicate, 0);
Expand All @@ -92,20 +92,25 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
int node = HappyTreeFriends::getRoot(_bvh); // start with root
do
{
if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node)))
bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node);

if (predicate(
(is_leaf
? HappyTreeFriends::getLeafBoundingVolume(_bvh, node)
: HappyTreeFriends::getInternalBoundingVolume(_bvh, node))))
{
if (!HappyTreeFriends::isLeaf(_bvh, node))
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
else
if (is_leaf)
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node)))
return;
node = HappyTreeFriends::getRope(_bvh, node);
}
else
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
}
else
{
Expand Down Expand Up @@ -255,6 +260,14 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
heap(UnmanagedStaticVector<PairIndexDistance>(buffer.data(),
buffer.size()));

auto &bvh = _bvh;
auto const distance = [&predicate, &bvh](int j) {
return predicate.distance(
HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getLeafBoundingVolume(bvh, j)
: HappyTreeFriends::getInternalBoundingVolume(bvh, j));
};

constexpr int SENTINEL = -1;
int stack[64];
auto *stack_ptr = stack;
Expand Down Expand Up @@ -285,10 +298,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
left_child = HappyTreeFriends::getLeftChild(_bvh, node);
right_child = HappyTreeFriends::getRightChild(_bvh, node);

distance_left = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, left_child));
distance_right = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, right_child));
distance_left = distance(left_child);
distance_right = distance(right_child);

if (distance_left < radius)
{
Expand Down Expand Up @@ -337,8 +348,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
// This is a theoretically unnecessary duplication of distance
// calculation for stack nodes. However, for Cuda it's better than
// putting the distances in stack.
distance_node = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, node));
distance_node = distance(node);
}
#else
distance_node = *--stack_distance_ptr;
Expand Down Expand Up @@ -423,7 +433,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
auto const &predicate = Access::get(_predicates, queryIndex);
auto const root = 0;
auto const &root_bounding_volume =
HappyTreeFriends::getBoundingVolume(_bvh, root);
HappyTreeFriends::getLeafBoundingVolume(_bvh, root);
using distance_type =
decltype(distance(getGeometry(predicate), root_bounding_volume));
constexpr auto inf =
Expand All @@ -440,7 +450,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
using ArborX::Details::HappyTreeFriends;

using distance_type = decltype(predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, 0)));
HappyTreeFriends::getInternalBoundingVolume(_bvh, 0)));
using PairIndexDistance = Kokkos::pair<int, distance_type>;
struct CompareDistance
{
Expand All @@ -460,6 +470,14 @@ struct TreeTraversal<BVH, Predicates, Callback,
constexpr auto inf =
KokkosExt::ArithmeticTraits::infinity<distance_type>::value;

auto &bvh = _bvh;
auto const distance = [&predicate, &bvh](int j) {
return predicate.distance(
HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getLeafBoundingVolume(bvh, j)
: HappyTreeFriends::getInternalBoundingVolume(bvh, j));
};

int node = HappyTreeFriends::getRoot(_bvh);
int left_child;
int right_child;
Expand All @@ -484,12 +502,10 @@ struct TreeTraversal<BVH, Predicates, Callback,
left_child = HappyTreeFriends::getLeftChild(_bvh, node);
right_child = HappyTreeFriends::getRightChild(_bvh, node);

auto const distance_left = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, left_child));
auto const distance_left = distance(left_child);
auto const left_pair = Kokkos::make_pair(left_child, distance_left);

auto const distance_right = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, right_child));
auto const distance_right = distance(right_child);
auto const right_pair = Kokkos::make_pair(right_child, distance_right);

auto const &closer_pair =
Expand Down
6 changes: 4 additions & 2 deletions src/details/ArborX_DetailsTreeVisualization.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -121,7 +121,9 @@ struct TreeVisualization
auto const node_label = getNodeLabel(tree, node);
auto const node_attributes = getNodeAttributes(tree, node);
auto const bounding_volume =
HappyTreeFriends::getBoundingVolume(tree, node);
HappyTreeFriends::isLeaf(tree, node)
? HappyTreeFriends::getLeafBoundingVolume(tree, node)
: HappyTreeFriends::getInternalBoundingVolume(tree, node);
auto const min_corner = bounding_volume.minCorner();
auto const max_corner = bounding_volume.maxCorner();
_os << R"(\draw)" << node_attributes << " " << min_corner << " rectangle "
Expand Down
13 changes: 8 additions & 5 deletions src/details/ArborX_MinimumSpanningTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,14 @@ struct FindComponentNearestNeighbors
constexpr auto inf = KokkosExt::ArithmeticTraits::infinity<float>::value;

auto const distance = [bounding_volume_i =
HappyTreeFriends::getBoundingVolume(_bvh, i),
HappyTreeFriends::getLeafBoundingVolume(_bvh, i),
&bvh = _bvh](int j) {
using Details::distance;
return distance(bounding_volume_i,
HappyTreeFriends::getBoundingVolume(bvh, j));
auto &&bounding_volume_j =
(HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getLeafBoundingVolume(bvh, j)
: HappyTreeFriends::getInternalBoundingVolume(bvh, j));
return distance(bounding_volume_i, bounding_volume_j);
};

auto const component = _labels(i);
Expand Down Expand Up @@ -680,8 +683,8 @@ void resetSharedRadii(ExecutionSpace const &space, BVH const &bvh,
auto const r =
metric(HappyTreeFriends::getLeafPermutationIndex(bvh, i),
HappyTreeFriends::getLeafPermutationIndex(bvh, j),
distance(HappyTreeFriends::getBoundingVolume(bvh, i),
HappyTreeFriends::getBoundingVolume(bvh, j)));
distance(HappyTreeFriends::getLeafBoundingVolume(bvh, i),
HappyTreeFriends::getLeafBoundingVolume(bvh, j)));
Kokkos::atomic_min(&radii(label_i), r);
Kokkos::atomic_min(&radii(label_j), r);
}
Expand Down