diff --git a/src/details/ArborX_DetailsDistributedTreeImpl.hpp b/src/details/ArborX_DetailsDistributedTreeImpl.hpp index ecff3bfbe..20d2a540b 100644 --- a/src/details/ArborX_DetailsDistributedTreeImpl.hpp +++ b/src/details/ArborX_DetailsDistributedTreeImpl.hpp @@ -473,7 +473,7 @@ struct CallbackWithDistance // need to be fixed with a proper callback abstraction. int const leaf_node_index = _rev_permute(index); auto const &leaf_node_bounding_volume = - HappyTreeFriends::getBoundingVolume(_tree, leaf_node_index); + HappyTreeFriends::getLeafBoundingVolume(_tree, leaf_node_index); out({index, distance(getGeometry(query), leaf_node_bounding_volume)}); } }; diff --git a/src/details/ArborX_DetailsHalfTraversal.hpp b/src/details/ArborX_DetailsHalfTraversal.hpp index a4a2709f2..5ed7886c7 100644 --- a/src/details/ArborX_DetailsHalfTraversal.hpp +++ b/src/details/ArborX_DetailsHalfTraversal.hpp @@ -53,25 +53,30 @@ struct HalfTraversal KOKKOS_FUNCTION void operator()(int i) const { auto const predicate = - _get_predicate(HappyTreeFriends::getBoundingVolume(_bvh, i)); + _get_predicate(HappyTreeFriends::getLeafBoundingVolume(_bvh, i)); auto const leaf_permutation_i = HappyTreeFriends::getLeafPermutationIndex(_bvh, i); int node = HappyTreeFriends::getRope(_bvh, i); while (node != ROPE_SENTINEL) { - if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node))) + bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node); + + if (predicate( + (is_leaf + ? HappyTreeFriends::getLeafBoundingVolume(_bvh, node) + : HappyTreeFriends::getInternalBoundingVolume(_bvh, node)))) { - if (!HappyTreeFriends::isLeaf(_bvh, node)) - { - node = HappyTreeFriends::getLeftChild(_bvh, node); - } - else + if (is_leaf) { _callback(leaf_permutation_i, HappyTreeFriends::getLeafPermutationIndex(_bvh, node)); node = HappyTreeFriends::getRope(_bvh, node); } + else + { + node = HappyTreeFriends::getLeftChild(_bvh, node); + } } else { diff --git a/src/details/ArborX_DetailsHappyTreeFriends.hpp b/src/details/ArborX_DetailsHappyTreeFriends.hpp index 6cc45cef7..d861b8bdd 100644 --- a/src/details/ArborX_DetailsHappyTreeFriends.hpp +++ b/src/details/ArborX_DetailsHappyTreeFriends.hpp @@ -19,10 +19,9 @@ #include #include // declval -namespace ArborX -{ -namespace Details +namespace ArborX::Details { + struct HappyTreeFriends { template @@ -47,16 +46,32 @@ struct HappyTreeFriends } template + static KOKKOS_FUNCTION +// FIXME_HIP See https://github.com/arborx/ArborX/issues/553 +#ifdef __HIP_DEVICE_COMPILE__ + auto +#else + auto const & +#endif + getInternalBoundingVolume(BVH const &bvh, int i) + { + return bvh._internal_nodes(internalIndex(bvh, i)).bounding_volume; + } + + template + static KOKKOS_FUNCTION // FIXME_HIP See https://github.com/arborx/ArborX/issues/553 #ifdef __HIP_DEVICE_COMPILE__ - static KOKKOS_FUNCTION auto getBoundingVolume(BVH const &bvh, int i) + auto #else - static KOKKOS_FUNCTION auto const &getBoundingVolume(BVH const &bvh, int i) + auto const & #endif + getLeafBoundingVolume(BVH const &bvh, int i) { - auto const internal_i = internalIndex(bvh, i); - return (internal_i >= 0 ? bvh._internal_nodes(internal_i).bounding_volume - : bvh._leaf_nodes(i).bounding_volume); + static_assert( + std::is_same_v); + return bvh._leaf_nodes(i).bounding_volume; } template @@ -83,12 +98,10 @@ struct HappyTreeFriends template static KOKKOS_FUNCTION auto getRope(BVH const &bvh, int i) { - auto const internal_i = internalIndex(bvh, i); - return (internal_i >= 0 ? bvh._internal_nodes(internal_i).rope - : bvh._leaf_nodes(i).rope); + return (isLeaf(bvh, i) ? bvh._leaf_nodes(i).rope + : bvh._internal_nodes(internalIndex(bvh, i)).rope); } }; -} // namespace Details -} // namespace ArborX +} // namespace ArborX::Details #endif diff --git a/src/details/ArborX_DetailsTreeTraversal.hpp b/src/details/ArborX_DetailsTreeTraversal.hpp index 0fd7fef3a..ce0cc4377 100644 --- a/src/details/ArborX_DetailsTreeTraversal.hpp +++ b/src/details/ArborX_DetailsTreeTraversal.hpp @@ -1,5 +1,5 @@ /**************************************************************************** - * Copyright (c) 2017-2022 by the ArborX authors * + * Copyright (c) 2017-2023 by the ArborX authors * * All rights reserved. * * * * This file is part of the ArborX library. ArborX is * @@ -78,7 +78,7 @@ struct TreeTraversal auto const &predicate = Access::get(_predicates, queryIndex); auto const root = 0; auto const &root_bounding_volume = - HappyTreeFriends::getBoundingVolume(_bvh, root); + HappyTreeFriends::getLeafBoundingVolume(_bvh, root); if (predicate(root_bounding_volume)) { _callback(predicate, 0); @@ -92,13 +92,14 @@ struct TreeTraversal int node = HappyTreeFriends::getRoot(_bvh); // start with root do { - if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node))) + bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node); + + if (predicate( + (is_leaf + ? HappyTreeFriends::getLeafBoundingVolume(_bvh, node) + : HappyTreeFriends::getInternalBoundingVolume(_bvh, node)))) { - if (!HappyTreeFriends::isLeaf(_bvh, node)) - { - node = HappyTreeFriends::getLeftChild(_bvh, node); - } - else + if (is_leaf) { if (invoke_callback_and_check_early_exit( _callback, predicate, @@ -106,6 +107,10 @@ struct TreeTraversal return; node = HappyTreeFriends::getRope(_bvh, node); } + else + { + node = HappyTreeFriends::getLeftChild(_bvh, node); + } } else { @@ -255,6 +260,14 @@ struct TreeTraversal heap(UnmanagedStaticVector(buffer.data(), buffer.size())); + auto &bvh = _bvh; + auto const distance = [&predicate, &bvh](int j) { + return predicate.distance( + HappyTreeFriends::isLeaf(bvh, j) + ? HappyTreeFriends::getLeafBoundingVolume(bvh, j) + : HappyTreeFriends::getInternalBoundingVolume(bvh, j)); + }; + constexpr int SENTINEL = -1; int stack[64]; auto *stack_ptr = stack; @@ -285,10 +298,8 @@ struct TreeTraversal left_child = HappyTreeFriends::getLeftChild(_bvh, node); right_child = HappyTreeFriends::getRightChild(_bvh, node); - distance_left = predicate.distance( - HappyTreeFriends::getBoundingVolume(_bvh, left_child)); - distance_right = predicate.distance( - HappyTreeFriends::getBoundingVolume(_bvh, right_child)); + distance_left = distance(left_child); + distance_right = distance(right_child); if (distance_left < radius) { @@ -337,8 +348,7 @@ struct TreeTraversal // This is a theoretically unnecessary duplication of distance // calculation for stack nodes. However, for Cuda it's better than // putting the distances in stack. - distance_node = predicate.distance( - HappyTreeFriends::getBoundingVolume(_bvh, node)); + distance_node = distance(node); } #else distance_node = *--stack_distance_ptr; @@ -423,7 +433,7 @@ struct TreeTraversal; struct CompareDistance { @@ -460,6 +470,14 @@ struct TreeTraversal::value; + auto &bvh = _bvh; + auto const distance = [&predicate, &bvh](int j) { + return predicate.distance( + HappyTreeFriends::isLeaf(bvh, j) + ? HappyTreeFriends::getLeafBoundingVolume(bvh, j) + : HappyTreeFriends::getInternalBoundingVolume(bvh, j)); + }; + int node = HappyTreeFriends::getRoot(_bvh); int left_child; int right_child; @@ -484,12 +502,10 @@ struct TreeTraversal::value; auto const distance = [bounding_volume_i = - HappyTreeFriends::getBoundingVolume(_bvh, i), + HappyTreeFriends::getLeafBoundingVolume(_bvh, i), &bvh = _bvh](int j) { using Details::distance; - return distance(bounding_volume_i, - HappyTreeFriends::getBoundingVolume(bvh, j)); + auto &&bounding_volume_j = + (HappyTreeFriends::isLeaf(bvh, j) + ? HappyTreeFriends::getLeafBoundingVolume(bvh, j) + : HappyTreeFriends::getInternalBoundingVolume(bvh, j)); + return distance(bounding_volume_i, bounding_volume_j); }; auto const component = _labels(i); @@ -680,8 +683,8 @@ void resetSharedRadii(ExecutionSpace const &space, BVH const &bvh, auto const r = metric(HappyTreeFriends::getLeafPermutationIndex(bvh, i), HappyTreeFriends::getLeafPermutationIndex(bvh, j), - distance(HappyTreeFriends::getBoundingVolume(bvh, i), - HappyTreeFriends::getBoundingVolume(bvh, j))); + distance(HappyTreeFriends::getLeafBoundingVolume(bvh, i), + HappyTreeFriends::getLeafBoundingVolume(bvh, j))); Kokkos::atomic_min(&radii(label_i), r); Kokkos::atomic_min(&radii(label_j), r); }