From 90adbe7ca824b90b8caa8c0d33dd93f463707b75 Mon Sep 17 00:00:00 2001 From: Andrey Prokopenko Date: Mon, 8 May 2023 14:10:16 -0400 Subject: [PATCH] Separate HappyTreeFunctions for internal and leaf nodes --- .../ArborX_DetailsDistributedTreeImpl.hpp | 4 +- src/details/ArborX_DetailsHalfTraversal.hpp | 28 +++++--- .../ArborX_DetailsHappyTreeFriends.hpp | 47 ++++++++----- src/details/ArborX_DetailsTreeTraversal.hpp | 69 +++++++++++++------ .../ArborX_DetailsTreeVisualization.hpp | 23 +++++-- src/details/ArborX_MinimumSpanningTree.hpp | 31 +++++---- 6 files changed, 132 insertions(+), 70 deletions(-) diff --git a/src/details/ArborX_DetailsDistributedTreeImpl.hpp b/src/details/ArborX_DetailsDistributedTreeImpl.hpp index ecff3bfbe..c60a63b26 100644 --- a/src/details/ArborX_DetailsDistributedTreeImpl.hpp +++ b/src/details/ArborX_DetailsDistributedTreeImpl.hpp @@ -472,8 +472,8 @@ struct CallbackWithDistance // the details of the local tree. Right now, this is the only way. Will // need to be fixed with a proper callback abstraction. int const leaf_node_index = _rev_permute(index); - auto const &leaf_node_bounding_volume = - HappyTreeFriends::getBoundingVolume(_tree, leaf_node_index); + auto const &leaf_node_bounding_volume = HappyTreeFriends::getBoundingVolume( + LeafNodeTag{}, _tree, leaf_node_index); out({index, distance(getGeometry(query), leaf_node_bounding_volume)}); } }; diff --git a/src/details/ArborX_DetailsHalfTraversal.hpp b/src/details/ArborX_DetailsHalfTraversal.hpp index a4a2709f2..f0e84a4a8 100644 --- a/src/details/ArborX_DetailsHalfTraversal.hpp +++ b/src/details/ArborX_DetailsHalfTraversal.hpp @@ -52,30 +52,38 @@ struct HalfTraversal KOKKOS_FUNCTION void operator()(int i) const { - auto const predicate = - _get_predicate(HappyTreeFriends::getBoundingVolume(_bvh, i)); + auto const predicate = _get_predicate( + HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, _bvh, i)); auto const leaf_permutation_i = HappyTreeFriends::getLeafPermutationIndex(_bvh, i); - int node = HappyTreeFriends::getRope(_bvh, i); + int node = HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, i); while (node != ROPE_SENTINEL) { - if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node))) + auto const internal_node = HappyTreeFriends::internalIndex(_bvh, node); + bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node); + + if (predicate((is_leaf ? HappyTreeFriends::getBoundingVolume( + LeafNodeTag{}, _bvh, node) + : HappyTreeFriends::getBoundingVolume( + InternalNodeTag{}, _bvh, internal_node)))) { - if (!HappyTreeFriends::isLeaf(_bvh, node)) + if (is_leaf) { - node = HappyTreeFriends::getLeftChild(_bvh, node); + _callback(leaf_permutation_i, + HappyTreeFriends::getLeafPermutationIndex(_bvh, node)); + node = HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node); } else { - _callback(leaf_permutation_i, - HappyTreeFriends::getLeafPermutationIndex(_bvh, node)); - node = HappyTreeFriends::getRope(_bvh, node); + node = HappyTreeFriends::getLeftChild(_bvh, node); } } else { - node = HappyTreeFriends::getRope(_bvh, node); + node = (is_leaf ? HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node) + : HappyTreeFriends::getRope(InternalNodeTag{}, _bvh, + internal_node)); } } } diff --git a/src/details/ArborX_DetailsHappyTreeFriends.hpp b/src/details/ArborX_DetailsHappyTreeFriends.hpp index 6cc45cef7..925d7abd8 100644 --- a/src/details/ArborX_DetailsHappyTreeFriends.hpp +++ b/src/details/ArborX_DetailsHappyTreeFriends.hpp @@ -19,10 +19,13 @@ #include #include // declval -namespace ArborX -{ -namespace Details +namespace ArborX::Details { +struct LeafNodeTag +{}; +struct InternalNodeTag +{}; + struct HappyTreeFriends { template @@ -46,17 +49,23 @@ struct HappyTreeFriends return i - (int)bvh.size(); } - template + template + static KOKKOS_FUNCTION // FIXME_HIP See https://github.com/arborx/ArborX/issues/553 #ifdef __HIP_DEVICE_COMPILE__ - static KOKKOS_FUNCTION auto getBoundingVolume(BVH const &bvh, int i) + auto #else - static KOKKOS_FUNCTION auto const &getBoundingVolume(BVH const &bvh, int i) + auto const & #endif + getBoundingVolume(Tag, BVH const &bvh, int i) { - auto const internal_i = internalIndex(bvh, i); - return (internal_i >= 0 ? bvh._internal_nodes(internal_i).bounding_volume - : bvh._leaf_nodes(i).bounding_volume); + static_assert( + std::is_same_v); + if constexpr (std::is_same_v) + return bvh._internal_nodes(i).bounding_volume; + else + return bvh._leaf_nodes(i).bounding_volume; } template @@ -77,18 +86,22 @@ struct HappyTreeFriends static KOKKOS_FUNCTION auto getRightChild(BVH const &bvh, int i) { assert(!isLeaf(bvh, i)); - return getRope(bvh, getLeftChild(bvh, i)); + auto left_child = getLeftChild(bvh, i); + bool const is_leaf = isLeaf(bvh, left_child); + return (is_leaf ? getRope(LeafNodeTag{}, bvh, left_child) + : getRope(InternalNodeTag{}, bvh, + internalIndex(bvh, left_child))); } - template - static KOKKOS_FUNCTION auto getRope(BVH const &bvh, int i) + template + static KOKKOS_FUNCTION auto getRope(Tag, BVH const &bvh, int i) { - auto const internal_i = internalIndex(bvh, i); - return (internal_i >= 0 ? bvh._internal_nodes(internal_i).rope - : bvh._leaf_nodes(i).rope); + if constexpr (std::is_same_v) + return bvh._internal_nodes(i).rope; + else + return bvh._leaf_nodes(i).rope; } }; -} // namespace Details -} // namespace ArborX +} // namespace ArborX::Details #endif diff --git a/src/details/ArborX_DetailsTreeTraversal.hpp b/src/details/ArborX_DetailsTreeTraversal.hpp index 0fd7fef3a..123839a68 100644 --- a/src/details/ArborX_DetailsTreeTraversal.hpp +++ b/src/details/ArborX_DetailsTreeTraversal.hpp @@ -1,5 +1,5 @@ /**************************************************************************** - * Copyright (c) 2017-2022 by the ArborX authors * + * Copyright (c) 2017-2023 by the ArborX authors * * All rights reserved. * * * * This file is part of the ArborX library. ArborX is * @@ -78,7 +78,7 @@ struct TreeTraversal auto const &predicate = Access::get(_predicates, queryIndex); auto const root = 0; auto const &root_bounding_volume = - HappyTreeFriends::getBoundingVolume(_bvh, root); + HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, _bvh, root); if (predicate(root_bounding_volume)) { _callback(predicate, 0); @@ -92,24 +92,32 @@ struct TreeTraversal int node = HappyTreeFriends::getRoot(_bvh); // start with root do { - if (predicate(HappyTreeFriends::getBoundingVolume(_bvh, node))) + auto const internal_node = HappyTreeFriends::internalIndex(_bvh, node); + bool const is_leaf = HappyTreeFriends::isLeaf(_bvh, node); + + if (predicate((is_leaf ? HappyTreeFriends::getBoundingVolume( + LeafNodeTag{}, _bvh, node) + : HappyTreeFriends::getBoundingVolume( + InternalNodeTag{}, _bvh, internal_node)))) { - if (!HappyTreeFriends::isLeaf(_bvh, node)) - { - node = HappyTreeFriends::getLeftChild(_bvh, node); - } - else + if (is_leaf) { if (invoke_callback_and_check_early_exit( _callback, predicate, HappyTreeFriends::getLeafPermutationIndex(_bvh, node))) return; - node = HappyTreeFriends::getRope(_bvh, node); + node = HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node); + } + else + { + node = HappyTreeFriends::getLeftChild(_bvh, node); } } else { - node = HappyTreeFriends::getRope(_bvh, node); + node = (is_leaf ? HappyTreeFriends::getRope(LeafNodeTag{}, _bvh, node) + : HappyTreeFriends::getRope(InternalNodeTag{}, _bvh, + internal_node)); } } while (node != ROPE_SENTINEL); } @@ -255,6 +263,17 @@ struct TreeTraversal heap(UnmanagedStaticVector(buffer.data(), buffer.size())); + auto &bvh = _bvh; + auto const distance = [&predicate, &bvh](int j) { + using Details::distance; + return predicate.distance( + HappyTreeFriends::isLeaf(bvh, j) + ? HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j) + : HappyTreeFriends::getBoundingVolume( + InternalNodeTag{}, bvh, + HappyTreeFriends::internalIndex(bvh, j))); + }; + constexpr int SENTINEL = -1; int stack[64]; auto *stack_ptr = stack; @@ -285,10 +304,8 @@ struct TreeTraversal left_child = HappyTreeFriends::getLeftChild(_bvh, node); right_child = HappyTreeFriends::getRightChild(_bvh, node); - distance_left = predicate.distance( - HappyTreeFriends::getBoundingVolume(_bvh, left_child)); - distance_right = predicate.distance( - HappyTreeFriends::getBoundingVolume(_bvh, right_child)); + distance_left = distance(left_child); + distance_right = distance(right_child); if (distance_left < radius) { @@ -337,8 +354,7 @@ struct TreeTraversal // This is a theoretically unnecessary duplication of distance // calculation for stack nodes. However, for Cuda it's better than // putting the distances in stack. - distance_node = predicate.distance( - HappyTreeFriends::getBoundingVolume(_bvh, node)); + distance_node = distance(node); } #else distance_node = *--stack_distance_ptr; @@ -423,7 +439,7 @@ struct TreeTraversal; struct CompareDistance { @@ -460,6 +476,17 @@ struct TreeTraversal::value; + auto &bvh = _bvh; + auto const distance = [&predicate, &bvh](int j) { + using Details::distance; + return predicate.distance( + HappyTreeFriends::isLeaf(bvh, j) + ? HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j) + : HappyTreeFriends::getBoundingVolume( + InternalNodeTag{}, bvh, + HappyTreeFriends::internalIndex(bvh, j))); + }; + int node = HappyTreeFriends::getRoot(_bvh); int left_child; int right_child; @@ -484,12 +511,10 @@ struct TreeTraversal::value; - auto const distance = [bounding_volume_i = - HappyTreeFriends::getBoundingVolume(_bvh, i), - &bvh = _bvh](int j) { - using Details::distance; - return distance(bounding_volume_i, - HappyTreeFriends::getBoundingVolume(bvh, j)); - }; + auto const distance = + [bounding_volume_i = + HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, _bvh, i), + &bvh = _bvh](int j) { + using Details::distance; + return distance( + bounding_volume_i, + (HappyTreeFriends::isLeaf(bvh, j) + ? HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j) + : HappyTreeFriends::getBoundingVolume( + InternalNodeTag{}, bvh, + HappyTreeFriends::internalIndex(bvh, j)))); + }; auto const component = _labels(i); auto const predicate = [label_i = component, &labels = _labels](int j) { @@ -677,11 +683,12 @@ void resetSharedRadii(ExecutionSpace const &space, BVH const &bvh, auto const label_j = labels(j); if (label_i != label_j) { - auto const r = - metric(HappyTreeFriends::getLeafPermutationIndex(bvh, i), - HappyTreeFriends::getLeafPermutationIndex(bvh, j), - distance(HappyTreeFriends::getBoundingVolume(bvh, i), - HappyTreeFriends::getBoundingVolume(bvh, j))); + auto const r = metric( + HappyTreeFriends::getLeafPermutationIndex(bvh, i), + HappyTreeFriends::getLeafPermutationIndex(bvh, j), + distance( + HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, i), + HappyTreeFriends::getBoundingVolume(LeafNodeTag{}, bvh, j))); Kokkos::atomic_min(&radii(label_i), r); Kokkos::atomic_min(&radii(label_j), r); }