From 9b59272cdc8db3c95b40745976ddb1d8f480f7ac Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Fri, 25 Feb 2022 16:57:45 -0800 Subject: [PATCH] bug fix --- cpp/include/cugraph/prims/extract_if_e.cuh | 30 ++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/cpp/include/cugraph/prims/extract_if_e.cuh b/cpp/include/cugraph/prims/extract_if_e.cuh index 46b79d56d98..1c4ed54b220 100644 --- a/cpp/include/cugraph/prims/extract_if_e.cuh +++ b/cpp/include/cugraph/prims/extract_if_e.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -156,6 +156,15 @@ extract_if_e(raft::handle_t const& handle, auto matrix_partition = matrix_partition_device_view_t( graph_view.get_matrix_partition_view(i)); + + auto matrix_partition_row_value_input = adj_matrix_row_value_input; + auto matrix_partition_col_value_input = adj_matrix_col_value_input; + if constexpr (GraphViewType::is_adj_matrix_transposed) { + matrix_partition_col_value_input.set_local_adj_matrix_partition_idx(i); + } else { + matrix_partition_row_value_input.set_local_adj_matrix_partition_idx(i); + } + detail::decompress_matrix_partition_to_edgelist( handle, matrix_partition, @@ -169,15 +178,16 @@ extract_if_e(raft::handle_t const& handle, edgelist_majors.begin(), edgelist_minors.begin(), (*edgelist_weights).begin())); cur_size += static_cast(thrust::distance( edge_first + cur_size, - thrust::remove_if( - handle.get_thrust_policy(), - edge_first + cur_size, - edge_first + cur_size + edgelist_edge_counts[i], - detail::call_e_op_t{ - matrix_partition, adj_matrix_row_value_input, adj_matrix_col_value_input, e_op}))); + thrust::remove_if(handle.get_thrust_policy(), + edge_first + cur_size, + edge_first + cur_size + edgelist_edge_counts[i], + detail::call_e_op_t{matrix_partition, + matrix_partition_row_value_input, + matrix_partition_col_value_input, + e_op}))); } else { auto edge_first = thrust::make_zip_iterator( thrust::make_tuple(edgelist_majors.begin(), edgelist_minors.begin()));