Skip to content

Commit

Permalink
Add feedback during training of xgb (#484)
Browse files Browse the repository at this point in the history
  • Loading branch information
qbc2016 authored Jan 4, 2023
1 parent b3f823d commit 709c618
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 40 deletions.
65 changes: 34 additions & 31 deletions federatedscope/vertical_fl/xgb_base/worker/Test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,38 +42,15 @@ def test_for_root(self, tree_num):
def test_for_node(self, tree_num, node_num):
if node_num >= 2**self.client.max_tree_depth - 1:
if tree_num + 1 < self.client.num_of_trees:
# TODO: add feedback during training
logger.info(f'----------- Building a new tree (Tree '
f'#{tree_num + 1}) -------------')
# build the next tree
self.client.fs.compute_for_root(tree_num + 1)

if (tree_num + 1) % self.client._cfg.eval.freq == 0:
self.feedback_results(tree_num)
else:
logger.info(f'----------- Building a new tree (Tree '
f'#{tree_num + 1}) -------------')
# build the next tree
self.client.fs.compute_for_root(tree_num + 1)
else:
metrics = self.evaluation()
self.client.comm_manager.send(
Message(msg_type='test_result',
sender=self.client.ID,
state=self.client.state,
receiver=self.client.server_id,
content=(tree_num, metrics)))

self.client.comm_manager.send(
Message(msg_type='send_feature_importance',
sender=self.client.ID,
state=self.client.state,
receiver=[
each for each in list(
self.client.comm_manager.neighbors.keys())
if each != self.client.server_id
and each != self.client.ID
],
content='None'))
self.client.comm_manager.send(
Message(msg_type='feature_importance',
sender=self.client.ID,
state=self.client.state,
receiver=self.client.server_id,
content=self.client.feature_importance))
self.feedback_results(tree_num)
elif self.client.tree_list[tree_num][node_num].weight:
self.client.test_result += self.client.tree_list[tree_num][
node_num].indicator * self.client.tree_list[tree_num][
Expand Down Expand Up @@ -112,3 +89,29 @@ def callback_func_for_LR(self, message: Message):
2].indicator = self.client.tree_list[
tree_num][node_num].indicator * R
self.test_for_node(tree_num, node_num + 1)

def feedback_results(self, tree_num):
metrics = self.evaluation()
self.client.comm_manager.send(
Message(msg_type='test_result',
sender=self.client.ID,
state=self.client.state,
receiver=self.client.server_id,
content=(tree_num, metrics)))
self.client.comm_manager.send(
Message(
msg_type='send_feature_importance',
sender=self.client.ID,
state=self.client.state,
receiver=[
each
for each in list(self.client.comm_manager.neighbors.keys())
if each != self.client.server_id and each != self.client.ID
],
content='None'))
self.client.comm_manager.send(
Message(msg_type='feature_importance',
sender=self.client.ID,
state=self.client.state,
receiver=self.client.server_id,
content=self.client.feature_importance))
8 changes: 8 additions & 0 deletions federatedscope/vertical_fl/xgb_base/worker/XGBClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self,
self.callback_func_for_compute_next_node)
self.register_handlers('send_feature_importance',
self.callback_func_for_send_feature_importance)
self.register_handlers('continue', self.callback_func_for_continue)
self.register_handlers('finish', self.callback_func_for_finish)

def _init_data_related_var(self):
Expand Down Expand Up @@ -234,5 +235,12 @@ def callback_func_for_send_feature_importance(self, message: Message):
receiver=self.server_id,
content=self.feature_importance))

def callback_func_for_continue(self, message: Message):
tree_num = message.content
logger.info(f'---------- Building a new tree (Tree '
f'#{tree_num + 1}) -------------')
# build the next tree
self.fs.compute_for_root(tree_num + 1)

def callback_func_for_finish(self, message: Message):
pass
27 changes: 18 additions & 9 deletions federatedscope/vertical_fl/xgb_base/worker/XGBServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,25 @@ def callback_func_for_feature_importance(self, message: Message):
role='Server #',
forms=self._cfg.eval.report)
formatted_logs['feature_importance'] = self.feature_importance_dict
self.feature_importance_dict = {}
logger.info(formatted_logs)
self.comm_manager.send(
Message(msg_type='finish',
sender=self.ID,
receiver=list(
self.comm_manager.get_neighbors().keys()),
state=self.state,
content='None'))
# jump out running
self.state = self.total_round_num + 1
if self.tree_num + 1 < self.num_of_trees:
self.comm_manager.send(
Message(msg_type='continue',
sender=self.ID,
receiver=self.num_of_parties,
state=self.state,
content=self.tree_num))
else:
self.comm_manager.send(
Message(msg_type='finish',
sender=self.ID,
receiver=list(
self.comm_manager.get_neighbors().keys()),
state=self.state,
content='None'))
# jump out running
self.state = self.total_round_num + 1

def callback_func_for_test_result(self, message: Message):
self.tree_num, self.metrics = message.content

0 comments on commit 709c618

Please # to comment.