From 3427bc4d11776903a7939ca107630eee6153c12a Mon Sep 17 00:00:00 2001 From: qbc Date: Tue, 3 Jan 2023 16:09:35 +0800 Subject: [PATCH 1/2] add feedback during training of xgb --- .../vertical_fl/xgb_base/worker/Test_base.py | 65 ++++++++++--------- .../vertical_fl/xgb_base/worker/XGBClient.py | 8 +++ .../vertical_fl/xgb_base/worker/XGBServer.py | 27 +++++--- 3 files changed, 60 insertions(+), 40 deletions(-) diff --git a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py index ef53f328c..37ce71f00 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/Test_base.py +++ b/federatedscope/vertical_fl/xgb_base/worker/Test_base.py @@ -42,38 +42,15 @@ def test_for_root(self, tree_num): def test_for_node(self, tree_num, node_num): if node_num >= 2**self.client.max_tree_depth - 1: if tree_num + 1 < self.client.num_of_trees: - # TODO: add feedback during training - logger.info(f'----------- Building a new tree (Tree ' - f'#{tree_num + 1}) -------------') - # build the next tree - self.client.fs.compute_for_root(tree_num + 1) - + if (tree_num + 1) % self.client._cfg.eval.freq == 0: + self.feedback_results(tree_num) + else: + logger.info(f'----------- Building a new tree (Tree ' + f'#{tree_num + 1}) -------------') + # build the next tree + self.client.fs.compute_for_root(tree_num + 1) else: - metrics = self.evaluation() - self.client.comm_manager.send( - Message(msg_type='test_result', - sender=self.client.ID, - state=self.client.state, - receiver=self.client.server_id, - content=(tree_num, metrics))) - - self.client.comm_manager.send( - Message(msg_type='send_feature_importance', - sender=self.client.ID, - state=self.client.state, - receiver=[ - each for each in list( - self.client.comm_manager.neighbors.keys()) - if each != self.client.server_id - and each != self.client.ID - ], - content='None')) - self.client.comm_manager.send( - Message(msg_type='feature_importance', - sender=self.client.ID, - state=self.client.state, - receiver=self.client.server_id, - content=self.client.feature_importance)) + self.feedback_results(tree_num) elif self.client.tree_list[tree_num][node_num].weight: self.client.test_result += self.client.tree_list[tree_num][ node_num].indicator * self.client.tree_list[tree_num][ @@ -112,3 +89,29 @@ def callback_func_for_LR(self, message: Message): 2].indicator = self.client.tree_list[ tree_num][node_num].indicator * R self.test_for_node(tree_num, node_num + 1) + + def feedback_results(self, tree_num): + metrics = self.evaluation() + self.client.comm_manager.send( + Message(msg_type='test_result', + sender=self.client.ID, + state=self.client.state, + receiver=self.client.server_id, + content=(tree_num, metrics))) + self.client.comm_manager.send( + Message( + msg_type='send_feature_importance', + sender=self.client.ID, + state=self.client.state, + receiver=[ + each + for each in list(self.client.comm_manager.neighbors.keys()) + if each != self.client.server_id and each != self.client.ID + ], + content='None')) + self.client.comm_manager.send( + Message(msg_type='feature_importance', + sender=self.client.ID, + state=self.client.state, + receiver=self.client.server_id, + content=self.client.feature_importance)) diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py index b34c68ba9..06c07e064 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py @@ -62,6 +62,7 @@ def __init__(self, self.callback_func_for_compute_next_node) self.register_handlers('send_feature_importance', self.callback_func_for_send_feature_importance) + self.register_handlers('continue', self.callback_func_for_continue) self.register_handlers('finish', self.callback_func_for_finish) def _init_data_related_var(self): @@ -234,5 +235,12 @@ def callback_func_for_send_feature_importance(self, message: Message): receiver=self.server_id, content=self.feature_importance)) + def callback_func_for_continue(self, message: Message): + tree_num = message.content + logger.info(f'----------- Building a new tree (Tree ' + f'#{tree_num + 1}) -------------') + # build the next tree + self.fs.compute_for_root(tree_num + 1) + def callback_func_for_finish(self, message: Message): pass diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py index 5589d5964..b1f488c98 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBServer.py @@ -88,16 +88,25 @@ def callback_func_for_feature_importance(self, message: Message): role='Server #', forms=self._cfg.eval.report) formatted_logs['feature_importance'] = self.feature_importance_dict + self.feature_importance_dict = {} logger.info(formatted_logs) - self.comm_manager.send( - Message(msg_type='finish', - sender=self.ID, - receiver=list( - self.comm_manager.get_neighbors().keys()), - state=self.state, - content='None')) - # jump out running - self.state = self.total_round_num + 1 + if self.tree_num + 1 < self.num_of_trees: + self.comm_manager.send( + Message(msg_type='continue', + sender=self.ID, + receiver=self.num_of_parties, + state=self.state, + content=self.tree_num)) + else: + self.comm_manager.send( + Message(msg_type='finish', + sender=self.ID, + receiver=list( + self.comm_manager.get_neighbors().keys()), + state=self.state, + content='None')) + # jump out running + self.state = self.total_round_num + 1 def callback_func_for_test_result(self, message: Message): self.tree_num, self.metrics = message.content From 602874e63684f167a240129c5ef77d102be7ca87 Mon Sep 17 00:00:00 2001 From: qbc Date: Tue, 3 Jan 2023 16:12:48 +0800 Subject: [PATCH 2/2] minor changes --- federatedscope/vertical_fl/xgb_base/worker/XGBClient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py index 06c07e064..f49514a68 100644 --- a/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py +++ b/federatedscope/vertical_fl/xgb_base/worker/XGBClient.py @@ -237,7 +237,7 @@ def callback_func_for_send_feature_importance(self, message: Message): def callback_func_for_continue(self, message: Message): tree_num = message.content - logger.info(f'----------- Building a new tree (Tree ' + logger.info(f'---------- Building a new tree (Tree ' f'#{tree_num + 1}) -------------') # build the next tree self.fs.compute_for_root(tree_num + 1)