Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify files to support distributed mode for XGBoost #553

Merged
merged 4 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions federatedscope/vertical_fl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def fetch_train_data(self, index=None):
feature_index, self.batch_x = self.dataloader.sample_feature(
self.batch_x)

# convert 'range' to 'list'
# to support gRPC protocols in distributed mode
batch_index = list(batch_index)

# If the complete trainset is used, we only need to get the slices
# from the complete feature order info according to the feature index,
# rather than re-ordering the instance
Expand Down
1 change: 1 addition & 0 deletions federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def callback_func_for_model_para(self, message: Message):

# other clients receive the data-sample information
def callback_func_for_data_sample(self, message: Message):
self.state = message.state
batch_index, sender = message.content, message.sender
_, feature_order_info = self.trainer.fetch_train_data(
index=batch_index)
Expand Down
5 changes: 3 additions & 2 deletions federatedscope/vertical_fl/xgb_base/worker/train_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _find_and_send_split(self, split_ref, tree_num, node_num):
state=self.state,
receiver=[client_id],
content=(tree_num, node_num, split_ref,
split_child))
int(split_child)))
if client_id == self.ID:
self.callback_func_for_split(send_message)
else:
Expand Down Expand Up @@ -156,6 +156,7 @@ def callback_funcs_for_local_best_gain(self, message: Message):
client_id = message.sender
self.msg_buffer['train'][client_id] = (local_best_gain, improved_flag,
split_info)

if len(self.msg_buffer['train']) == self.client_num:
received_msg = copy.deepcopy(self.msg_buffer['train'])
self.msg_buffer['train'].clear()
Expand All @@ -171,7 +172,7 @@ def callback_funcs_for_local_best_gain(self, message: Message):
state=self.state,
receiver=[split_client_id],
content=(tree_num, node_num, split_ref,
split_child))
int(split_child)))
if split_client_id == self.ID:
self.callback_func_for_split(send_message)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,32 @@ distribute:
client_port: 50052
role: 'client'
data_idx: 1
model:
type: xgb_tree
lambda_: 0.1
gamma: 0
num_of_trees: 10
max_tree_depth: 3
data:
root: data/
type: credit
splits: [0.8, 0.2]
dataloader:
type: raw
batch_size: 2000
model:
type: lr
criterion:
type: CrossEntropyLoss
trainer:
type: verticaltrainer
train:
optimizer:
lambda_: 0.1
gamma: 0
num_of_trees: 10
max_tree_depth: 3
# learning rate for xgb model
eta: 0.5
vertical_dims: [5, 10]
xgb_base:
vertical:
use: True
use_bin: True
criterion:
type: CrossEntropyLoss
trainer:
type: none
dims: [5, 10]
algo: 'xgb'
data_size_for_debug: 2000
eval:
freq: 3
best_res_update_round_wise_key: test_loss
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,32 @@ distribute:
client_port: 50053
role: 'client'
data_idx: 2
model:
type: xgb_tree
lambda_: 0.1
gamma: 0
num_of_trees: 10
max_tree_depth: 3
data:
root: data/
type: credit
splits: [0.8, 0.2]
dataloader:
type: raw
batch_size: 2000
model:
type: lr
criterion:
type: CrossEntropyLoss
trainer:
type: verticaltrainer
train:
optimizer:
lambda_: 0.1
gamma: 0
num_of_trees: 10
max_tree_depth: 3
# learning rate for xgb model
eta: 0.5
vertical_dims: [5, 10]
xgb_base:
vertical:
use: True
use_bin: True
criterion:
type: CrossEntropyLoss
trainer:
type: none
dims: [5, 10]
algo: 'xgb'
data_size_for_debug: 2000
eval:
freq: 3
best_res_update_round_wise_key: test_loss
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,32 @@ distribute:
server_port: 50051
role: 'server'
data_idx: 0
model:
type: xgb_tree
lambda_: 0.1
gamma: 0
num_of_trees: 10
max_tree_depth: 3
data:
root: data/
type: credit
splits: [0.8, 0.2]
dataloader:
type: raw
batch_size: 2000
model:
type: lr
criterion:
type: CrossEntropyLoss
trainer:
type: verticaltrainer
train:
optimizer:
lambda_: 0.1
gamma: 0
num_of_trees: 10
max_tree_depth: 3
# learning rate for xgb model
eta: 0.5
vertical_dims: [5, 10]
xgb_base:
vertical:
use: True
use_bin: True
criterion:
type: CrossEntropyLoss
trainer:
type: none
dims: [5, 10]
algo: 'xgb'
data_size_for_debug: 2000
eval:
freq: 3
best_res_update_round_wise_key: test_loss