Skip to content

Commit

Permalink
Cut down on the number of AccessTraits<Primitives, PredicatesTag>
Browse files: browse the repository at this point in the history
  • Loading branch information
aprokop committed Dec 6, 2023
1 parent f0acf28 commit aa590f6
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 155 deletions.
36 changes: 19 additions & 17 deletions src/ArborX_BruteForce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,23 @@ class BasicBruteForce
void query(ExecutionSpace const &space, Predicates const &predicates,
Callback const &callback, Ignore = Ignore()) const;

template <typename ExecutionSpace, typename Predicates,
template <typename ExecutionSpace, typename UserPredicates,
typename CallbackOrView, typename View, typename... Args>
std::enable_if_t<Kokkos::is_view_v<std::decay_t<View>>>
query(ExecutionSpace const &space, Predicates const &predicates,
query(ExecutionSpace const &space, UserPredicates const &user_predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&...args) const
{
KokkosExt::ScopedProfileRegion guard("ArborX::BruteForce::query_crs");

Details::CrsGraphWrapperImpl::
check_valid_callback_if_first_argument_is_not_a_view<value_type>(
callback_or_view, predicates, view);
callback_or_view, user_predicates, view);

using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Tag = typename Predicates::value_type::Tag;

Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, predicates,
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}
Expand Down Expand Up @@ -207,25 +207,27 @@ BasicBruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::

template <typename MemorySpace, typename Value, typename IndexableGetter,
typename BoundingVolume>
template <typename ExecutionSpace, typename Predicates, typename Callback,
template <typename ExecutionSpace, typename UserPredicates, typename Callback,
typename Ignore>
void BasicBruteForce<MemorySpace, Value, IndexableGetter,
BoundingVolume>::query(ExecutionSpace const &space,
Predicates const &predicates,
Callback const &callback,
Ignore) const
void BasicBruteForce<MemorySpace, Value, IndexableGetter, BoundingVolume>::
query(ExecutionSpace const &space, UserPredicates const &user_predicates,
Callback const &callback, Ignore) const
{
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
Details::check_valid_access_traits(PredicatesTag{}, user_predicates);
Details::check_valid_callback<value_type>(callback, user_predicates);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Predicates::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
using Tag = typename Details::AccessTraitsHelper<Access>::tag;

Predicates predicates{user_predicates};

using Tag = typename Predicates::value_type::Tag;
static_assert(std::is_same<Tag, Details::SpatialPredicateTag>{},
"nearest query not implemented yet");
Details::check_valid_callback<Value>(callback, predicates);

Kokkos::Profiling::pushRegion("ArborX::BruteForce::query::spatial");

Expand Down
21 changes: 15 additions & 6 deletions src/ArborX_DistributedTree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#ifndef ARBORX_DISTRIBUTED_TREE_HPP
#define ARBORX_DISTRIBUTED_TREE_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_DetailsDistributedTreeImpl.hpp>
#include <ArborX_DetailsUtils.hpp> // accumulate
Expand Down Expand Up @@ -86,16 +87,24 @@ class DistributedTree
* - \c distances Computed distances (optional and only for nearest
* predicates).
*/
template <typename ExecutionSpace, typename Predicates, typename... Args>
void query(ExecutionSpace const &space, Predicates const &predicates,
template <typename ExecutionSpace, typename UserPredicates, typename... Args>
void query(ExecutionSpace const &space, UserPredicates const &user_predicates,
Args &&...args) const
{
static_assert(Kokkos::is_execution_space<ExecutionSpace>::value);
using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
static_assert(
KokkosExt::is_accessible_from<typename Predicates::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");

using Tag = typename Predicates::value_type::Tag;
using DeviceType = Kokkos::Device<ExecutionSpace, MemorySpace>;
Details::DistributedTreeImpl<DeviceType>::queryDispatch(
Tag{}, *this, space, predicates, std::forward<Args>(args)...);
Tag{}, *this, space, Predicates{user_predicates},
std::forward<Args>(args)...);
}

private:
Expand Down
28 changes: 15 additions & 13 deletions src/ArborX_LinearBVH.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,23 @@ class BasicBoundingVolumeHierarchy
Experimental::TraversalPolicy const &policy =
Experimental::TraversalPolicy()) const;

template <typename ExecutionSpace, typename Predicates,
template <typename ExecutionSpace, typename UserPredicates,
typename CallbackOrView, typename View, typename... Args>
std::enable_if_t<Kokkos::is_view_v<std::decay_t<View>>>
query(ExecutionSpace const &space, Predicates const &predicates,
query(ExecutionSpace const &space, UserPredicates const &user_predicates,
CallbackOrView &&callback_or_view, View &&view, Args &&...args) const
{
KokkosExt::ScopedProfileRegion guard("ArborX::BVH::query_crs");

Details::CrsGraphWrapperImpl::
check_valid_callback_if_first_argument_is_not_a_view<value_type>(
callback_or_view, predicates, view);
callback_or_view, user_predicates, view);

using Access = AccessTraits<Predicates, PredicatesTag>;
using Tag = typename Details::AccessTraitsHelper<Access>::tag;
using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
using Tag = typename Predicates::value_type::Tag;

Details::CrsGraphWrapperImpl::queryDispatch(
Tag{}, *this, space, predicates,
Tag{}, *this, space, Predicates{user_predicates},
std::forward<CallbackOrView>(callback_or_view),
std::forward<View>(view), std::forward<Args>(args)...);
}
Expand Down Expand Up @@ -326,24 +326,26 @@ BasicBoundingVolumeHierarchy<MemorySpace, Value, IndexableGetter,

template <typename MemorySpace, typename Value, typename IndexableGetter,
typename BoundingVolume>
template <typename ExecutionSpace, typename Predicates, typename Callback>
template <typename ExecutionSpace, typename UserPredicates, typename Callback>
void BasicBoundingVolumeHierarchy<
MemorySpace, Value, IndexableGetter,
BoundingVolume>::query(ExecutionSpace const &space,
Predicates const &predicates,
UserPredicates const &user_predicates,
Callback const &callback,
Experimental::TraversalPolicy const &policy) const
{
static_assert(
KokkosExt::is_accessible_from<MemorySpace, ExecutionSpace>::value);
Details::check_valid_access_traits(PredicatesTag{}, predicates);
using Access = AccessTraits<Predicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Access::memory_space,
Details::check_valid_access_traits(PredicatesTag{}, user_predicates);
Details::check_valid_callback<value_type>(callback, user_predicates);

using Predicates = Details::AccessValues<UserPredicates, PredicatesTag>;
static_assert(KokkosExt::is_accessible_from<typename Predicates::memory_space,
ExecutionSpace>::value,
"Predicates must be accessible from the execution space");
Details::check_valid_callback<value_type>(callback, predicates);
Predicates predicates{user_predicates};

using Tag = typename Details::AccessTraitsHelper<Access>::tag;
using Tag = typename Predicates::value_type::Tag;
std::string profiling_prefix = "ArborX::BVH::query::";
if constexpr (std::is_same_v<Tag, Details::SpatialPredicateTag>)
{
Expand Down
26 changes: 9 additions & 17 deletions src/details/ArborX_DetailsBatchedQueries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#ifndef ARBORX_DETAILS_BATCHED_QUERIES_HPP
#define ARBORX_DETAILS_BATCHED_QUERIES_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_DetailsAlgorithms.hpp> // returnCentroid, translateAndScale
#include <ArborX_DetailsKokkosExtViewHelpers.hpp>
Expand Down Expand Up @@ -53,11 +52,10 @@ struct BatchedQueries
Box const &scene_bounding_box,
Predicates const &predicates)
{
using Access = AccessTraits<Predicates, PredicatesTag>;
auto const n_queries = Access::size(predicates);
auto const n_queries = predicates.size();

using Point = std::decay_t<decltype(returnCentroid(
getGeometry(Access::get(predicates, 0))))>;
using Point =
std::decay_t<decltype(returnCentroid(getGeometry(predicates(0))))>;
using LinearOrderingValueType =
Kokkos::detected_t<SpaceFillingCurveProjectionArchetypeExpression,
SpaceFillingCurve, Box, Point>;
Expand All @@ -69,9 +67,8 @@ struct BatchedQueries
"ArborX::BatchedQueries::project_predicates_onto_space_filling_curve",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_LAMBDA(int i) {
linear_ordering_indices(i) =
curve(scene_bounding_box,
returnCentroid(getGeometry(Access::get(predicates, i))));
linear_ordering_indices(i) = curve(
scene_bounding_box, returnCentroid(getGeometry(predicates(i))));
});

return sortObjects(space, linear_ordering_indices);
Expand All @@ -85,24 +82,19 @@ struct BatchedQueries
applyPermutation(ExecutionSpace const &space,
Kokkos::View<unsigned int const *, DeviceType> permute,
Predicates const &v)
-> Kokkos::View<typename AccessTraitsHelper<
AccessTraits<Predicates, PredicatesTag>>::type *,
DeviceType>
-> Kokkos::View<typename Predicates::value_type *, DeviceType>
{
using Access = AccessTraits<Predicates, PredicatesTag>;
auto const n = Access::size(v);
auto const n = v.size();
ARBORX_ASSERT(permute.extent(0) == n);

using T = std::decay_t<decltype(Access::get(
std::declval<Predicates const &>(), std::declval<int>()))>;
Kokkos::View<T *, DeviceType> w(
Kokkos::View<typename Predicates::value_type *, DeviceType> w(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
"ArborX::permuted_predicates"),
n);
Kokkos::parallel_for(
"ArborX::BatchedQueries::permute_entries",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n),
KOKKOS_LAMBDA(int i) { w(i) = Access::get(v, permute(i)); });
KOKKOS_LAMBDA(int i) { w(i) = v(permute(i)); });

return w;
}
Expand Down
9 changes: 3 additions & 6 deletions src/details/ArborX_DetailsBruteForceImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#ifndef ARBORX_DETAILS_BRUTE_FORCE_IMPL_HPP
#define ARBORX_DETAILS_BRUTE_FORCE_IMPL_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_DetailsAlgorithms.hpp> // expand
#include <ArborX_Exception.hpp>

Expand Down Expand Up @@ -53,12 +52,11 @@ struct BruteForceImpl
Callback const &callback)
{
using TeamPolicy = Kokkos::TeamPolicy<ExecutionSpace>;
using AccessPredicates = AccessTraits<Predicates, PredicatesTag>;
using PredicateType = typename AccessTraitsHelper<AccessPredicates>::type;
using PredicateType = typename Predicates::value_type;
using IndexableType = std::decay_t<decltype(indexables(0))>;

int const n_indexables = values.size();
int const n_predicates = AccessPredicates::size(predicates);
int const n_predicates = predicates.size();
int max_scratch_size = TeamPolicy::scratch_size_max(0);
// half of the scratch memory used by predicates and half for indexables
int const predicates_per_team =
Expand Down Expand Up @@ -110,8 +108,7 @@ struct BruteForceImpl
Kokkos::parallel_for(
Kokkos::TeamVectorRange(teamMember, predicates_in_this_team),
[&](const int q) {
scratch_predicates(q) =
AccessPredicates::get(predicates, predicate_start + q);
scratch_predicates(q) = predicates(predicate_start + q);
});
Kokkos::parallel_for(
Kokkos::TeamVectorRange(teamMember, indexables_in_this_team),
Expand Down
40 changes: 15 additions & 25 deletions src/details/ArborX_DetailsCrsGraphWrapperImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#ifndef ARBORX_DETAIL_CRS_GRAPH_WRAPPER_IMPL_HPP
#define ARBORX_DETAIL_CRS_GRAPH_WRAPPER_IMPL_HPP

#include <ArborX_AccessTraits.hpp>
#include <ArborX_Box.hpp>
#include <ArborX_Callbacks.hpp>
#include <ArborX_DetailsBatchedQueries.hpp>
Expand Down Expand Up @@ -50,8 +49,8 @@ struct FirstPassNoBufferOptimizationTag
struct SecondPassTag
{};

template <typename PassTag, typename Predicates, typename Callback,
typename OutputView, typename CountView, typename PermutedOffset>
template <typename PassTag, typename Callback, typename OutputView,
typename CountView, typename PermutedOffset>
struct InsertGenerator
{
Callback _callback;
Expand All @@ -60,11 +59,9 @@ struct InsertGenerator
PermutedOffset _permuted_offset;

using ValueType = typename OutputView::value_type;
using Access = AccessTraits<Predicates, PredicatesTag>;
using PredicateType = typename AccessTraitsHelper<Access>::type;

template <typename Value>
KOKKOS_FUNCTION auto operator()(PredicateType const &predicate,
template <typename Predicate, typename Value>
KOKKOS_FUNCTION auto operator()(Predicate const &predicate,
Value const &value) const
{
auto const predicate_index = getData(predicate);
Expand Down Expand Up @@ -126,8 +123,7 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,

static_assert(Kokkos::is_execution_space<ExecutionSpace>{});

using Access = AccessTraits<Predicates, PredicatesTag>;
auto const n_queries = Access::size(predicates);
auto const n_queries = predicates.size();

Kokkos::Profiling::pushRegion("ArborX::CrsGraphWrapper::two_pass");

Expand All @@ -150,9 +146,8 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,
{
tree.query(
space, permuted_predicates,
InsertGenerator<FirstPassTag, PermutedPredicates, Callback, OutputView,
CountView, PermutedOffset>{callback, out, counts,
permuted_offset},
InsertGenerator<FirstPassTag, Callback, OutputView, CountView,
PermutedOffset>{callback, out, counts, permuted_offset},
ArborX::Experimental::TraversalPolicy().setPredicateSorting(false));

// Detecting overflow is a local operation that needs to be done for every
Expand Down Expand Up @@ -185,9 +180,9 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,
{
tree.query(
space, permuted_predicates,
InsertGenerator<FirstPassNoBufferOptimizationTag, PermutedPredicates,
Callback, OutputView, CountView, PermutedOffset>{
callback, out, counts, permuted_offset},
InsertGenerator<FirstPassNoBufferOptimizationTag, Callback, OutputView,
CountView, PermutedOffset>{callback, out, counts,
permuted_offset},
ArborX::Experimental::TraversalPolicy().setPredicateSorting(false));
// This may not be true, but it does not matter. As long as we have
// (n_results == 0) check before second pass, this value is not used.
Expand Down Expand Up @@ -247,9 +242,8 @@ void queryImpl(ExecutionSpace const &space, Tree const &tree,

tree.query(
space, permuted_predicates,
InsertGenerator<SecondPassTag, PermutedPredicates, Callback, OutputView,
CountView, PermutedOffset>{callback, out, counts,
permuted_offset},
InsertGenerator<SecondPassTag, Callback, OutputView, CountView,
PermutedOffset>{callback, out, counts, permuted_offset},
ArborX::Experimental::TraversalPolicy().setPredicateSorting(false));

Kokkos::Profiling::popRegion();
Expand Down Expand Up @@ -298,9 +292,7 @@ allocateAndInitializeStorage(Tag, ExecutionSpace const &space,
Predicates const &predicates, OffsetView &offset,
OutView &out, int buffer_size)
{
using Access = AccessTraits<Predicates, PredicatesTag>;

auto const n_queries = Access::size(predicates);
auto const n_queries = predicates.size();
KokkosExt::reallocWithoutInitializing(space, offset, n_queries + 1);

buffer_size = std::abs(buffer_size);
Expand All @@ -324,16 +316,14 @@ allocateAndInitializeStorage(Tag, ExecutionSpace const &space,
Predicates const &predicates, OffsetView &offset,
OutView &out, int /*buffer_size*/)
{
using Access = AccessTraits<Predicates, PredicatesTag>;

auto const n_queries = Access::size(predicates);
auto const n_queries = predicates.size();
KokkosExt::reallocWithoutInitializing(space, offset, n_queries + 1);

Kokkos::parallel_for(
"ArborX::CrsGraphWrapper::query::nearest::"
"scan_queries_for_numbers_of_nearest_neighbors",
Kokkos::RangePolicy<ExecutionSpace>(space, 0, n_queries),
KOKKOS_LAMBDA(int i) { offset(i) = getK(Access::get(predicates, i)); });
KOKKOS_LAMBDA(int i) { offset(i) = getK(predicates(i)); });
exclusivePrefixSum(space, offset);

KokkosExt::reallocWithoutInitializing(space, out,
Expand Down
Loading

0 comments on commit aa590f6

Please sign in to comment.