Skip to content

Commit

Permalink
Merge pull request #349 from dalg24/allow_non_device_type_template_param
Browse files Browse the repository at this point in the history
  • Loading branch information
aprokop authored Jul 31, 2020
2 parents 60a9232 + 0308e58 commit ffed0fe
Showing 1 changed file with 116 additions and 122 deletions.
238 changes: 116 additions & 122 deletions src/details/ArborX_DetailsDistributedSearchTreeImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,11 @@
namespace ArborX
{

template <typename DeviceType, typename Enable>
class DistributedSearchTree;

namespace Details
{

struct CallbackDefaultSpatialPredicateWithRank
{
using tag = InlineCallbackTag;
int _rank;
template <typename Query, typename Insert>
KOKKOS_FUNCTION void operator()(Query const &, int index,
Expand All @@ -49,15 +45,14 @@ template <typename DeviceType>
struct DistributedSearchTreeImpl
{
// spatial queries
template <typename ExecutionSpace, typename Predicates>
static void
queryDispatch(SpatialPredicateTag,
DistributedSearchTree<typename DeviceType::memory_space,
void> const &tree,
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Ranks>
static std::enable_if_t<Kokkos::is_view<Indices>{} &&
Kokkos::is_view<Offset>{} && Kokkos::is_view<Ranks>{}>
queryDispatch(SpatialPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<int *, DeviceType> &ranks)
Indices &indices, Offset &offset, Ranks &ranks)
{
Kokkos::View<Kokkos::pair<int, int> *, DeviceType> out("pairs_index_rank",
0);
Expand All @@ -77,94 +72,93 @@ struct DistributedSearchTreeImpl
});
}

template <typename ExecutionSpace, typename Predicates, typename OutputView,
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename OutputView, typename OffsetView,
typename Callback>
static void
queryDispatch(SpatialPredicateTag,
DistributedSearchTree<typename DeviceType::memory_space,
void> const &tree,
static std::enable_if_t<Kokkos::is_view<OutputView>{} &&
Kokkos::is_view<OffsetView>{}>
queryDispatch(SpatialPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
Callback const &callback, OutputView &out,
Kokkos::View<int *, DeviceType> &offset);
Callback const &callback, OutputView &out, OffsetView &offset);

// nearest neighbors queries
template <typename ExecutionSpace, typename Predicates>
static void
queryDispatch(NearestPredicateTag,
DistributedSearchTree<typename DeviceType::memory_space,
void> const &tree,
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Ranks,
typename Distances = Kokkos::View<float *, DeviceType>>
static std::enable_if_t<
Kokkos::is_view<Indices>{} && Kokkos::is_view<Offset>{} &&
Kokkos::is_view<Ranks>{} && Kokkos::is_view<Distances>{}>
queryDispatch(NearestPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<int *, DeviceType> &ranks,
Kokkos::View<float *, DeviceType> *distances_ptr = nullptr);

template <typename ExecutionSpace, typename Predicates>
static void
queryDispatch(NearestPredicateTag tag,
DistributedSearchTree<typename DeviceType::memory_space,
void> const &tree,
Indices &indices, Offset &offset, Ranks &ranks,
Distances *distances_ptr = nullptr);

template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Ranks, typename Distances>
static std::enable_if_t<
Kokkos::is_view<Indices>{} && Kokkos::is_view<Offset>{} &&
Kokkos::is_view<Ranks>{} && Kokkos::is_view<Distances>{}>
queryDispatch(NearestPredicateTag tag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<int *, DeviceType> &ranks,
Kokkos::View<float *, DeviceType> &distances)
Indices &indices, Offset &offset, Ranks &ranks,
Distances &distances)
{
queryDispatch(tag, tree, space, queries, indices, offset, ranks,
&distances);
}

template <typename ExecutionSpace, typename Predicates>
static void
deviseStrategy(ExecutionSpace const &space, Predicates const &queries,
DistributedSearchTree<typename DeviceType::memory_space,
void> const &tree,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<float *, DeviceType> &);

template <typename ExecutionSpace, typename Predicates>
static void
reassessStrategy(ExecutionSpace const &space, Predicates const &queries,
DistributedSearchTree<typename DeviceType::memory_space,
void> const &tree,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<float *, DeviceType> &distances);

template <typename ExecutionSpace, typename Predicates, typename Query>
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Distances>
static void deviseStrategy(ExecutionSpace const &space,
Predicates const &queries,
DistributedTree const &tree, Indices &indices,
Offset &offset, Distances &);

template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Distances>
static void reassessStrategy(ExecutionSpace const &space,
Predicates const &queries,
DistributedTree const &tree, Indices &indices,
Offset &offset, Distances &distances);

template <typename ExecutionSpace, typename Predicates, typename Ranks,
typename Query>
static void forwardQueries(MPI_Comm comm, ExecutionSpace const &space,
Predicates const &queries,
Kokkos::View<int *, DeviceType> indices,
Kokkos::View<int *, DeviceType> offset,
Kokkos::View<Query *, DeviceType> &fwd_queries,
Kokkos::View<int *, DeviceType> &fwd_ids,
Kokkos::View<int *, DeviceType> &fwd_ranks);

template <typename ExecutionSpace, typename OutputView>
static void communicateResultsBack(
MPI_Comm comm, ExecutionSpace const &space, OutputView &view,
Kokkos::View<int *, DeviceType> offset,
Kokkos::View<int *, DeviceType> &ranks,
Kokkos::View<int *, DeviceType> &ids,
Kokkos::View<float *, DeviceType> *distances_ptr = nullptr);

template <typename ExecutionSpace, typename Predicates>
Ranks &fwd_ranks);

template <typename ExecutionSpace, typename OutputView, typename Ranks,
typename Distances = Kokkos::View<float *, DeviceType>>
static void communicateResultsBack(MPI_Comm comm, ExecutionSpace const &space,
OutputView &view,
Kokkos::View<int *, DeviceType> offset,
Ranks &ranks,
Kokkos::View<int *, DeviceType> &ids,
Distances *distances_ptr = nullptr);

template <typename ExecutionSpace, typename Predicates, typename Indices,
typename Offset, typename Ranks>
static void filterResults(ExecutionSpace const &space,
Predicates const &queries,
Kokkos::View<float *, DeviceType> distances,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<int *, DeviceType> &ranks);
Indices &indices, Offset &offset, Ranks &ranks);

template <typename ExecutionSpace, typename View, typename... OtherViews>
static void sortResults(ExecutionSpace const &space, View keys,
OtherViews... other_views);

template <typename ExecutionSpace>
template <typename ExecutionSpace, typename OffsetView>
static void countResults(ExecutionSpace const &space, int n_queries,
Kokkos::View<int *, DeviceType> query_ids,
Kokkos::View<int *, DeviceType> &offset);
OffsetView &offset);

template <typename ExecutionSpace, typename View>
static typename std::enable_if<Kokkos::is_view<View>::value>::type
Expand Down Expand Up @@ -287,13 +281,12 @@ DistributedSearchTreeImpl<DeviceType>::sendAcrossNetwork(
}

template <typename DeviceType>
template <typename ExecutionSpace, typename Predicates>
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Distances>
void DistributedSearchTreeImpl<DeviceType>::deviseStrategy(
ExecutionSpace const &space, Predicates const &queries,
DistributedSearchTree<typename DeviceType::memory_space, void> const &tree,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<float *, DeviceType> &)
DistributedTree const &tree, Indices &indices, Offset &offset, Distances &)
{
auto const &top_tree = tree._top_tree;
auto const &bottom_tree_sizes = tree._bottom_tree_sizes;
Expand Down Expand Up @@ -344,13 +337,13 @@ void DistributedSearchTreeImpl<DeviceType>::deviseStrategy(
}

template <typename DeviceType>
template <typename ExecutionSpace, typename Predicates>
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Distances>
void DistributedSearchTreeImpl<DeviceType>::reassessStrategy(
ExecutionSpace const &space, Predicates const &queries,
DistributedSearchTree<typename DeviceType::memory_space, void> const &tree,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<float *, DeviceType> &distances)
DistributedTree const &tree, Indices &indices, Offset &offset,
Distances &distances)
{
auto const &top_tree = tree._top_tree;
using Access = AccessTraits<Predicates, PredicatesTag>;
Expand Down Expand Up @@ -388,20 +381,20 @@ void DistributedSearchTreeImpl<DeviceType>::reassessStrategy(
}

template <typename DeviceType>
template <typename ExecutionSpace, typename Predicates>
void DistributedSearchTreeImpl<DeviceType>::queryDispatch(
NearestPredicateTag,
DistributedSearchTree<typename DeviceType::memory_space, void> const &tree,
ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<int *, DeviceType> &ranks,
Kokkos::View<float *, DeviceType> *distances_ptr)
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename Indices, typename Offset,
typename Ranks, typename Distances>
std::enable_if_t<Kokkos::is_view<Indices>{} && Kokkos::is_view<Offset>{} &&
Kokkos::is_view<Ranks>{} && Kokkos::is_view<Distances>{}>
DistributedSearchTreeImpl<DeviceType>::queryDispatch(
NearestPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries, Indices &indices,
Offset &offset, Ranks &ranks, Distances *distances_ptr)
{
auto const &bottom_tree = tree._bottom_tree;
auto comm = tree._comm;

Kokkos::View<float *, DeviceType> distances("distances", 0);
Distances distances("distances", 0);
if (distances_ptr)
distances = *distances_ptr;

Expand All @@ -417,11 +410,9 @@ void DistributedSearchTreeImpl<DeviceType>::queryDispatch(

// NOTE: the compiler would not deduce __range for the braced-init-list, but
// it works with a static_cast of each entry to a function pointer type.
using Strategy = void (*)(
ExecutionSpace const &, Predicates const &,
DistributedSearchTree<typename DeviceType::memory_space, void> const &,
Kokkos::View<int *, DeviceType> &, Kokkos::View<int *, DeviceType> &,
Kokkos::View<float *, DeviceType> &);
using Strategy =
void (*)(ExecutionSpace const &, Predicates const &,
DistributedTree const &, Indices &, Offset &, Distances &);
for (auto implementStrategy :
{static_cast<Strategy>(
DistributedSearchTreeImpl<DeviceType>::deviseStrategy),
Expand Down Expand Up @@ -463,14 +454,14 @@ void DistributedSearchTreeImpl<DeviceType>::queryDispatch(
}

template <typename DeviceType>
template <typename ExecutionSpace, typename Predicates, typename OutputView,
template <typename DistributedTree, typename ExecutionSpace,
typename Predicates, typename OutputView, typename OffsetView,
typename Callback>
void DistributedSearchTreeImpl<DeviceType>::queryDispatch(
SpatialPredicateTag,
DistributedSearchTree<typename DeviceType::memory_space, void> const &tree,
std::enable_if_t<Kokkos::is_view<OutputView>{} && Kokkos::is_view<OffsetView>{}>
DistributedSearchTreeImpl<DeviceType>::queryDispatch(
SpatialPredicateTag, DistributedTree const &tree,
ExecutionSpace const &space, Predicates const &queries,
Callback const &callback, OutputView &out,
Kokkos::View<int *, DeviceType> &offset)
Callback const &callback, OutputView &out, OffsetView &offset)
{
auto const &top_tree = tree._top_tree;
auto const &bottom_tree = tree._bottom_tree;
Expand Down Expand Up @@ -539,11 +530,10 @@ void DistributedSearchTreeImpl<DeviceType>::sortResults(
}

template <typename DeviceType>
template <typename ExecutionSpace>
template <typename ExecutionSpace, typename OffsetView>
void DistributedSearchTreeImpl<DeviceType>::countResults(
ExecutionSpace const &space, int n_queries,
Kokkos::View<int *, DeviceType> query_ids,
Kokkos::View<int *, DeviceType> &offset)
Kokkos::View<int *, DeviceType> query_ids, OffsetView &offset)
{
int const nnz = query_ids.extent(0);

Expand All @@ -559,14 +549,14 @@ void DistributedSearchTreeImpl<DeviceType>::countResults(
}

template <typename DeviceType>
template <typename ExecutionSpace, typename Predicates, typename Query>
template <typename ExecutionSpace, typename Predicates, typename Ranks,
typename Query>
void DistributedSearchTreeImpl<DeviceType>::forwardQueries(
MPI_Comm comm, ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<int *, DeviceType> indices,
Kokkos::View<int *, DeviceType> offset,
Kokkos::View<Query *, DeviceType> &fwd_queries,
Kokkos::View<int *, DeviceType> &fwd_ids,
Kokkos::View<int *, DeviceType> &fwd_ranks)
Kokkos::View<int *, DeviceType> &fwd_ids, Ranks &fwd_ranks)
{
int comm_rank;
MPI_Comm_rank(comm, &comm_rank);
Expand Down Expand Up @@ -624,13 +614,12 @@ void DistributedSearchTreeImpl<DeviceType>::forwardQueries(
}

template <typename DeviceType>
template <typename ExecutionSpace, typename OutputView>
template <typename ExecutionSpace, typename OutputView, typename Ranks,
typename Distances>
void DistributedSearchTreeImpl<DeviceType>::communicateResultsBack(
MPI_Comm comm, ExecutionSpace const &space, OutputView &out,
Kokkos::View<int *, DeviceType> offset,
Kokkos::View<int *, DeviceType> &ranks,
Kokkos::View<int *, DeviceType> &ids,
Kokkos::View<float *, DeviceType> *distances_ptr)
Kokkos::View<int *, DeviceType> offset, Ranks &ranks,
Kokkos::View<int *, DeviceType> &ids, Distances *distances_ptr)
{
int comm_rank;
MPI_Comm_rank(comm, &comm_rank);
Expand All @@ -641,7 +630,13 @@ void DistributedSearchTreeImpl<DeviceType>::communicateResultsBack(
// We are assuming here that if the same rank is related to multiple batches
// these batches appear consecutively. Hence, no reordering is necessary.
Distributor<DeviceType> distributor(comm);
int const n_imports = distributor.createFromSends(space, ranks, offset);
// FIXME Distributor::createFromSends takes two views of the same type by
// const reference. There were two easy ways out: either take the views by
// value or cast at the call site. I went with the latter. A proper fix
// involves more code cleanup in ArborX_DetailsDistributor.hpp than I am
// willing to do just now.
int const n_imports =
distributor.createFromSends(space, ranks, static_cast<Ranks>(offset));

Kokkos::View<int *, DeviceType> export_ranks(
Kokkos::ViewAllocateWithoutInitializing(ranks.label()), n_exports);
Expand Down Expand Up @@ -676,7 +671,7 @@ void DistributedSearchTreeImpl<DeviceType>::communicateResultsBack(

if (distances_ptr)
{
Kokkos::View<float *, DeviceType> &distances = *distances_ptr;
auto &distances = *distances_ptr;
Kokkos::View<float *, DeviceType> export_distances = distances;
Kokkos::View<float *, DeviceType> import_distances(
Kokkos::ViewAllocateWithoutInitializing(distances.label()), n_imports);
Expand All @@ -686,13 +681,12 @@ void DistributedSearchTreeImpl<DeviceType>::communicateResultsBack(
}

template <typename DeviceType>
template <typename ExecutionSpace, typename Predicates>
template <typename ExecutionSpace, typename Predicates, typename Indices,
typename Offset, typename Ranks>
void DistributedSearchTreeImpl<DeviceType>::filterResults(
ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<float *, DeviceType> distances,
Kokkos::View<int *, DeviceType> &indices,
Kokkos::View<int *, DeviceType> &offset,
Kokkos::View<int *, DeviceType> &ranks)
Kokkos::View<float *, DeviceType> distances, Indices &indices,
Offset &offset, Ranks &ranks)
{
using Access = AccessTraits<Predicates, PredicatesTag>;
int const n_queries = Access::size(queries);
Expand Down

0 comments on commit ffed0fe

Please sign in to comment.