Skip to content

Commit

Permalink
[code_style fix] graph_brpc_server cpplint (#49462)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhen38 authored Jan 9, 2023
1 parent 36c6c58 commit d4b3bfa
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 75 deletions.
126 changes: 70 additions & 56 deletions paddle/fluid/distributed/ps/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"

#include <string>
#include <thread> // NOLINT
#include <utility>

Expand Down Expand Up @@ -125,9 +126,9 @@ int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
((GraphTable *)table)->clear_nodes(type_id, idx_);
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
(reinterpret_cast<GraphTable *>(table))->clear_nodes(type_id, idx_);
return 0;
}

Expand All @@ -142,14 +143,16 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return 0;
}

int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list;
if (request.params_size() == 3) {
size_t weight_list_size = request.params(2).size() / sizeof(bool);
bool *is_weighted_buffer = (bool *)(request.params(2).c_str());
const bool *is_weighted_buffer =
reinterpret_cast<const bool *>(request.params(2).c_str());
is_weighted_list = std::vector<bool>(is_weighted_buffer,
is_weighted_buffer + weight_list_size);
}
Expand All @@ -161,7 +164,8 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
// weight_list_size);
// }

((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
(reinterpret_cast<GraphTable *>(table))
->add_graph_node(idx_, node_ids, is_weighted_list);
return 0;
}
int32_t GraphBrpcService::remove_graph_node(Table *table,
Expand All @@ -176,12 +180,13 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"remove_graph_node request requires at least 2 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);

((GraphTable *)table)->remove_graph_node(idx_, node_ids);
(reinterpret_cast<GraphTable *>(table))->remove_graph_node(idx_, node_ids);
return 0;
}
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
Expand Down Expand Up @@ -338,7 +343,7 @@ int32_t GraphBrpcService::StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
GraphBrpcServer *p_server = reinterpret_cast<GraphBrpcServer *>(_server);
std::thread t_stop([p_server]() {
p_server->Stop();
LOG(INFO) << "Server Stoped";
Expand Down Expand Up @@ -375,14 +380,14 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
response, -1, "pull_graph_list request requires at least 5 arguments");
return 0;
}
int type_id = *(int *)(request.params(0).c_str());
int idx = *(int *)(request.params(1).c_str());
int start = *(int *)(request.params(2).c_str());
int size = *(int *)(request.params(3).c_str());
int step = *(int *)(request.params(4).c_str());
int type_id = std::stoi(request.params(0).c_str());
int idx = std::stoi(request.params(1).c_str());
int start = std::stoi(request.params(2).c_str());
int size = std::stoi(request.params(3).c_str());
int step = std::stoi(request.params(4).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
((GraphTable *)table)
(reinterpret_cast<GraphTable *>(table))
->pull_graph_list(
type_id, idx, start, size, buffer, actual_size, false, step);
cntl->response_attachment().append(buffer.get(), actual_size);
Expand All @@ -401,14 +406,16 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
int sample_size = *(int *)(request.params(2).c_str());
bool need_weight = *(bool *)(request.params(3).c_str());
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); // NOLINT
const int sample_size =
*reinterpret_cast<const int *>(request.params(2).c_str());
const bool need_weight =
*reinterpret_cast<const bool *>(request.params(3).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table)
(reinterpret_cast<GraphTable *>(table))
->random_sample_neighbors(
idx_, node_data, sample_size, buffers, actual_sizes, need_weight);

Expand All @@ -425,18 +432,18 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
size_t size = *(uint64_t *)(request.params(2).c_str());
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
size_t size = std::stoull(request.params(2).c_str());
// size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
if (((GraphTable *)table)
->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
0) {
if (reinterpret_cast<GraphTable *>(table)->random_sample_nodes(
type_id, idx_, size, buffer, actual_size) == 0) {
cntl->response_attachment().append(buffer.get(), actual_size);
} else
} else {
cntl->response_attachment().append(NULL, 0);
}

return 0;
}
Expand All @@ -453,9 +460,10 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 3 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);

std::vector<std::string> feature_names =
Expand All @@ -464,7 +472,8 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
std::vector<std::vector<std::string>> feature(
feature_names.size(), std::vector<std::string>(node_num));

((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);
(reinterpret_cast<GraphTable *>(table))
->get_node_feat(idx_, node_ids, feature_names, feature);

for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
Expand Down Expand Up @@ -492,20 +501,21 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
return 0;
}

int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
int sample_size = *(int *)(request.params(2).c_str());
bool need_weight = *(bool *)(request.params(3).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
int sample_size = std::stoi(request.params(2).c_str());
bool need_weight = std::stoi(request.params(3).c_str());

std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
std::vector<uint64_t> local_id;
std::vector<int> local_query_idx;
size_t rank = GetRank();
for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
int server_index = (reinterpret_cast<GraphTable *>(table))
->get_server_index_by_id(node_data[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
Expand All @@ -514,10 +524,10 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
if (server2request[rank] != -1) {
auto pos = server2request[rank];
std::swap(request2server[pos],
request2server[(int)request2server.size() - 1]);
request2server[static_cast<int>(request2server.size()) - 1]);
server2request[request2server[pos]] = pos;
server2request[request2server[(int)request2server.size() - 1]] =
request2server.size() - 1;
server2request[request2server[static_cast<int>(request2server.size()) -
1]] = request2server.size() - 1;
}
size_t request_call_num = request2server.size();
std::vector<std::shared_ptr<char>> local_buffers;
Expand All @@ -526,8 +536,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
int server_index = (reinterpret_cast<GraphTable *>(table))
->get_server_index_by_id(node_data[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_data[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
Expand All @@ -550,7 +560,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
request_call_num](void *done) {
local_fut.get();
std::vector<int> actual_size;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
remote_call_num);
size_t fail_num = 0;
Expand Down Expand Up @@ -610,17 +620,19 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx)->set_client_id(rank);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));

closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
->add_params(
reinterpret_cast<char *>(node_id_buckets[request_idx].data()),
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
->add_params(reinterpret_cast<char *>(&sample_size), sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub(
((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
->add_params(reinterpret_cast<char *>(&need_weight), sizeof(bool));
PsService_Stub rpc_stub((reinterpret_cast<GraphBrpcServer *>(GetServer())
->GetCmdChannel(server_index)));
// GraphPsService_Stub rpc_stub =
// getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
Expand All @@ -630,7 +642,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure);
}
if (server2request[rank] != -1) {
((GraphTable *)table)
(reinterpret_cast<GraphTable *>(table))
->random_sample_neighbors(idx_,
node_id_buckets.back().data(),
sample_size,
Expand All @@ -655,10 +667,11 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());

size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);

// std::vector<std::string> feature_names =
Expand All @@ -675,15 +688,16 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,

for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer);
const size_t feat_len = *reinterpret_cast<const size_t *>(buffer);
buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat;
buffer += feat_len;
}
}

((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);
(reinterpret_cast<GraphTable *>(table))
->set_node_feat(idx_, node_ids, feature_names, features);

return 0;
}
Expand Down
Loading

0 comments on commit d4b3bfa

Please # to comment.