Skip to content

Commit

Permalink
extract_if_e bug fix. (#2096)
Browse files Browse the repository at this point in the history
extract_if_e primitive bug fix.

This primitive takes a device_view of row(or column)_properties_t objects, and in the internal implementation, we need to properly set the local edge partition ID to properly retrieve row properties (if CSR-ish data structure is used).

This cod was missing and this PR fixes the bug.

Authors:
  - Seunghwa Kang (https://github.com/seunghwak)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Kumar Aatish (https://github.com/kaatish)

URL: #2096
  • Loading branch information
seunghwak authored Mar 1, 2022
1 parent cef4924 commit 3cac4b1
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions cpp/include/cugraph/prims/extract_if_e.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -156,6 +156,15 @@ extract_if_e(raft::handle_t const& handle,
auto matrix_partition =
matrix_partition_device_view_t<vertex_t, edge_t, weight_t, GraphViewType::is_multi_gpu>(
graph_view.get_matrix_partition_view(i));

auto matrix_partition_row_value_input = adj_matrix_row_value_input;
auto matrix_partition_col_value_input = adj_matrix_col_value_input;
if constexpr (GraphViewType::is_adj_matrix_transposed) {
matrix_partition_col_value_input.set_local_adj_matrix_partition_idx(i);
} else {
matrix_partition_row_value_input.set_local_adj_matrix_partition_idx(i);
}

detail::decompress_matrix_partition_to_edgelist(
handle,
matrix_partition,
Expand All @@ -169,15 +178,16 @@ extract_if_e(raft::handle_t const& handle,
edgelist_majors.begin(), edgelist_minors.begin(), (*edgelist_weights).begin()));
cur_size += static_cast<size_t>(thrust::distance(
edge_first + cur_size,
thrust::remove_if(
handle.get_thrust_policy(),
edge_first + cur_size,
edge_first + cur_size + edgelist_edge_counts[i],
detail::call_e_op_t<GraphViewType,
AdjMatrixRowValueInputWrapper,
AdjMatrixColValueInputWrapper,
EdgeOp>{
matrix_partition, adj_matrix_row_value_input, adj_matrix_col_value_input, e_op})));
thrust::remove_if(handle.get_thrust_policy(),
edge_first + cur_size,
edge_first + cur_size + edgelist_edge_counts[i],
detail::call_e_op_t<GraphViewType,
AdjMatrixRowValueInputWrapper,
AdjMatrixColValueInputWrapper,
EdgeOp>{matrix_partition,
matrix_partition_row_value_input,
matrix_partition_col_value_input,
e_op})));
} else {
auto edge_first = thrust::make_zip_iterator(
thrust::make_tuple(edgelist_majors.begin(), edgelist_minors.begin()));
Expand Down

0 comments on commit 3cac4b1

Please # to comment.