Skip to content

Commit

Permalink
Separate HappyTreeFunctions for internal and leaf nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop committed May 13, 2023
1 parent 82488e1 commit 90adbe7
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 70 deletions.
4 changes: 2 additions & 2 deletions src/details/ArborX_DetailsDistributedTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@ struct CallbackWithDistance
// the details of the local tree. Right now, this is the only way. Will
// need to be fixed with a proper callback abstraction.
int const leaf_node_index = _rev_permute(index);
auto const &leaf_node_bounding_volume =
HappyTreeFriends::getBoundingVolume(_tree, leaf_node_index);
auto const &leaf_node_bounding_volume = HappyTreeFriends::getBoundingVolume(
LeafNodeTag{}, _tree, leaf_node_index);
out({index, distance(getGeometry(query), leaf_node_bounding_volume)});
}
};
Expand Down
28 changes: 18 additions & 10 deletions src/details/ArborX_DetailsHalfTraversal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,38 @@ struct HalfTraversal

KOKKOS_FUNCTION void operator()(int i) const
{
auto const predicate =
_get_predicate(HappyTreeFriends::getBoundingVolume(_bvh, i));
auto const predicate = _get_predicate(
HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, _bvh, i));
auto const leaf_permutation_i =
HappyTreeFriends::getLeafPermutationIndex(_bvh, i);

int node = HappyTreeFriends::getRope(_bvh, i);
int node = HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, i);
while (node != ROPE_SENTINEL)
{
if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node)))
auto const internal_node = HappyTreeFriends::internalIndex(_bvh, node);
bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node);

if (predicate((is_leaf ? HappyTreeFriends::getBoundingVolume(
LeafNodeTag{}, _bvh, node)
: HappyTreeFriends::getBoundingVolume(
InternalNodeTag{}, _bvh, internal_node))))
{
if (!HappyTreeFriends::isLeaf(_bvh, node))
if (is_leaf)
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
_callback(leaf_permutation_i,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node));
node = HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node);
}
else
{
_callback(leaf_permutation_i,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node));
node = HappyTreeFriends::getRope(_bvh, node);
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
}
else
{
node = HappyTreeFriends::getRope(_bvh, node);
node = (is_leaf ? HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node)
: HappyTreeFriends::getRope(InternalNodeTag{}, _bvh,
internal_node));
}
}
}
Expand Down
47 changes: 30 additions & 17 deletions src/details/ArborX_DetailsHappyTreeFriends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
#include <type_traits>
#include <utility> // declval

namespace ArborX
{
namespace Details
namespace ArborX::Details
{
struct LeafNodeTag
{};
struct InternalNodeTag
{};

struct HappyTreeFriends
{
template <class BVH>
Expand All @@ -46,17 +49,23 @@ struct HappyTreeFriends
return i - (int)bvh.size();
}

template <class BVH>
template <class Tag, class BVH>
static KOKKOS_FUNCTION
// FIXME_HIP See https://github.com/arborx/ArborX/issues/553
#ifdef __HIP_DEVICE_COMPILE__
static KOKKOS_FUNCTION auto getBoundingVolume(BVH const &bvh, int i)
auto
#else
static KOKKOS_FUNCTION auto const &getBoundingVolume(BVH const &bvh, int i)
auto const &
#endif
getBoundingVolume(Tag, BVH const &bvh, int i)
{
auto const internal_i = internalIndex(bvh, i);
return (internal_i >= 0 ? bvh._internal_nodes(internal_i).bounding_volume
: bvh._leaf_nodes(i).bounding_volume);
static_assert(
std::is_same_v<decltype(bvh._internal_nodes(0).bounding_volume),
decltype(bvh._leaf_nodes(0).bounding_volume)>);
if constexpr (std::is_same_v<Tag, InternalNodeTag>)
return bvh._internal_nodes(i).bounding_volume;
else
return bvh._leaf_nodes(i).bounding_volume;
}

template <class BVH>
Expand All @@ -77,18 +86,22 @@ struct HappyTreeFriends
static KOKKOS_FUNCTION auto getRightChild(BVH const &bvh, int i)
{
assert(!isLeaf(bvh, i));
return getRope(bvh, getLeftChild(bvh, i));
auto left_child = getLeftChild(bvh, i);
bool const is_leaf = isLeaf(bvh, left_child);
return (is_leaf ? getRope(LeafNodeTag{}, bvh, left_child)
: getRope(InternalNodeTag{}, bvh,
internalIndex(bvh, left_child)));
}

template <class BVH>
static KOKKOS_FUNCTION auto getRope(BVH const &bvh, int i)
template <class Tag, class BVH>
static KOKKOS_FUNCTION auto getRope(Tag, BVH const &bvh, int i)
{
auto const internal_i = internalIndex(bvh, i);
return (internal_i >= 0 ? bvh._internal_nodes(internal_i).rope
: bvh._leaf_nodes(i).rope);
if constexpr (std::is_same_v<Tag, InternalNodeTag>)
return bvh._internal_nodes(i).rope;
else
return bvh._leaf_nodes(i).rope;
}
};
} // namespace Details
} // namespace ArborX
} // namespace ArborX::Details

#endif
69 changes: 47 additions & 22 deletions src/details/ArborX_DetailsTreeTraversal.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -78,7 +78,7 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
auto const &predicate = Access::get(_predicates, queryIndex);
auto const root = 0;
auto const &root_bounding_volume =
HappyTreeFriends::getBoundingVolume(_bvh, root);
HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, _bvh, root);
if (predicate(root_bounding_volume))
{
_callback(predicate, 0);
Expand All @@ -92,24 +92,32 @@ struct TreeTraversal<BVH, Predicates, Callback, SpatialPredicateTag>
int node = HappyTreeFriends::getRoot(_bvh); // start with root
do
{
if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node)))
auto const internal_node = HappyTreeFriends::internalIndex(_bvh, node);
bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node);

if (predicate((is_leaf ? HappyTreeFriends::getBoundingVolume(
LeafNodeTag{}, _bvh, node)
: HappyTreeFriends::getBoundingVolume(
InternalNodeTag{}, _bvh, internal_node))))
{
if (!HappyTreeFriends::isLeaf(_bvh, node))
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
else
if (is_leaf)
{
if (invoke_callback_and_check_early_exit(
_callback, predicate,
HappyTreeFriends::getLeafPermutationIndex(_bvh, node)))
return;
node = HappyTreeFriends::getRope(_bvh, node);
node = HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node);
}
else
{
node = HappyTreeFriends::getLeftChild(_bvh, node);
}
}
else
{
node = HappyTreeFriends::getRope(_bvh, node);
node = (is_leaf ? HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node)
: HappyTreeFriends::getRope(InternalNodeTag{}, _bvh,
internal_node));
}
} while (node != ROPE_SENTINEL);
}
Expand Down Expand Up @@ -255,6 +263,17 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
heap(UnmanagedStaticVector<PairIndexDistance>(buffer.data(),
buffer.size()));

auto &bvh = _bvh;
auto const distance = [&predicate, &bvh](int j) {
using Details::distance;
return predicate.distance(
HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j)
: HappyTreeFriends::getBoundingVolume(
InternalNodeTag{}, bvh,
HappyTreeFriends::internalIndex(bvh, j)));
};

constexpr int SENTINEL = -1;
int stack[64];
auto *stack_ptr = stack;
Expand Down Expand Up @@ -285,10 +304,8 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
left_child = HappyTreeFriends::getLeftChild(_bvh, node);
right_child = HappyTreeFriends::getRightChild(_bvh, node);

distance_left = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, left_child));
distance_right = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, right_child));
distance_left = distance(left_child);
distance_right = distance(right_child);

if (distance_left < radius)
{
Expand Down Expand Up @@ -337,8 +354,7 @@ struct TreeTraversal<BVH, Predicates, Callback, NearestPredicateTag>
// This is a theoretically unnecessary duplication of distance
// calculation for stack nodes. However, for Cuda it's better than
// putting the distances in stack.
distance_node = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, node));
distance_node = distance(node);
}
#else
distance_node = *--stack_distance_ptr;
Expand Down Expand Up @@ -423,7 +439,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
auto const &predicate = Access::get(_predicates, queryIndex);
auto const root = 0;
auto const &root_bounding_volume =
HappyTreeFriends::getBoundingVolume(_bvh, root);
HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, _bvh, root);
using distance_type =
decltype(distance(getGeometry(predicate), root_bounding_volume));
constexpr auto inf =
Expand All @@ -440,7 +456,7 @@ struct TreeTraversal<BVH, Predicates, Callback,
using ArborX::Details::HappyTreeFriends;

using distance_type = decltype(predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, 0)));
HappyTreeFriends::getBoundingVolume(InternalNodeTag{}, _bvh, 0)));
using PairIndexDistance = Kokkos::pair<int, distance_type>;
struct CompareDistance
{
Expand All @@ -460,6 +476,17 @@ struct TreeTraversal<BVH, Predicates, Callback,
constexpr auto inf =
KokkosExt::ArithmeticTraits::infinity<distance_type>::value;

auto &bvh = _bvh;
auto const distance = [&predicate, &bvh](int j) {
using Details::distance;
return predicate.distance(
HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j)
: HappyTreeFriends::getBoundingVolume(
InternalNodeTag{}, bvh,
HappyTreeFriends::internalIndex(bvh, j)));
};

int node = HappyTreeFriends::getRoot(_bvh);
int left_child;
int right_child;
Expand All @@ -484,12 +511,10 @@ struct TreeTraversal<BVH, Predicates, Callback,
left_child = HappyTreeFriends::getLeftChild(_bvh, node);
right_child = HappyTreeFriends::getRightChild(_bvh, node);

auto const distance_left = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, left_child));
auto const distance_left = distance(left_child);
auto const left_pair = Kokkos::make_pair(left_child, distance_left);

auto const distance_right = predicate.distance(
HappyTreeFriends::getBoundingVolume(_bvh, right_child));
auto const distance_right = distance(right_child);
auto const right_pair = Kokkos::make_pair(right_child, distance_right);

auto const &closer_pair =
Expand Down
23 changes: 16 additions & 7 deletions src/details/ArborX_DetailsTreeVisualization.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/****************************************************************************
* Copyright (c) 2017-2022 by the ArborX authors *
* Copyright (c) 2017-2023 by the ArborX authors *
* All rights reserved. *
* *
* This file is part of the ArborX library. ArborX is *
Expand Down Expand Up @@ -29,6 +29,13 @@ std::ostream &operator<<(std::ostream &os, Point const &p)
os << "(" << p[0] << "," << p[1] << ")";
return os;
}
std::ostream &operator<<(std::ostream &os, Box const &box)
{
auto const min_corner = box.minCorner();
auto const max_corner = box.maxCorner();
os << min_corner << " rectangle " << max_corner;
return os;
}

struct TreeVisualization
{
Expand Down Expand Up @@ -120,12 +127,14 @@ struct TreeVisualization
{
auto const node_label = getNodeLabel(tree, node);
auto const node_attributes = getNodeAttributes(tree, node);
auto const bounding_volume =
HappyTreeFriends::getBoundingVolume(tree, node);
auto const min_corner = bounding_volume.minCorner();
auto const max_corner = bounding_volume.maxCorner();
_os << R"(\draw)" << node_attributes << " " << min_corner << " rectangle "
<< max_corner << " node {" << node_label << "};\n";
_os << R"(\draw)" << node_attributes << " ";
if (HappyTreeFriends::isLeaf(tree, node))
_os << HappyTreeFriends::getBoundingVolume(
InternalNodeTag{}, tree,
HappyTreeFriends::internalIndex(tree, node));
else
_os << HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, tree, node);
_os << " node {" << node_label << "};\n";
}
};

Expand Down
31 changes: 19 additions & 12 deletions src/details/ArborX_MinimumSpanningTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,19 @@ struct FindComponentNearestNeighbors
{
constexpr auto inf = KokkosExt::ArithmeticTraits::infinity<float>::value;

auto const distance = [bounding_volume_i =
HappyTreeFriends::getBoundingVolume(_bvh, i),
&bvh = _bvh](int j) {
using Details::distance;
return distance(bounding_volume_i,
HappyTreeFriends::getBoundingVolume(bvh, j));
};
auto const distance =
[bounding_volume_i =
HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, _bvh, i),
&bvh = _bvh](int j) {
using Details::distance;
return distance(
bounding_volume_i,
(HappyTreeFriends::isLeaf(bvh, j)
? HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j)
: HappyTreeFriends::getBoundingVolume(
InternalNodeTag{}, bvh,
HappyTreeFriends::internalIndex(bvh, j))));
};

auto const component = _labels(i);
auto const predicate = [label_i = component, &labels = _labels](int j) {
Expand Down Expand Up @@ -677,11 +683,12 @@ void resetSharedRadii(ExecutionSpace const &space, BVH const &bvh,
auto const label_j = labels(j);
if (label_i != label_j)
{
auto const r =
metric(HappyTreeFriends::getLeafPermutationIndex(bvh, i),
HappyTreeFriends::getLeafPermutationIndex(bvh, j),
distance(HappyTreeFriends::getBoundingVolume(bvh, i),
HappyTreeFriends::getBoundingVolume(bvh, j)));
auto const r = metric(
HappyTreeFriends::getLeafPermutationIndex(bvh, i),
HappyTreeFriends::getLeafPermutationIndex(bvh, j),
distance(
HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, i),
HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j)));
Kokkos::atomic_min(&radii(label_i), r);
Kokkos::atomic_min(&radii(label_j), r);
}
Expand Down

0 comments on commit 90adbe7

Please # to comment.