diff --git a/federatedscope/core/trainers/torch_trainer.py b/federatedscope/core/trainers/torch_trainer.py
index 5c7e4c404..e3b36a242 100644
--- a/federatedscope/core/trainers/torch_trainer.py
+++ b/federatedscope/core/trainers/torch_trainer.py
@@ -289,6 +289,7 @@ def _hook_on_batch_end(self, ctx):
         ctx.y_true = None
         ctx.y_prob = None
 
+
     def _hook_on_fit_end(self, ctx):
         """Evaluate metrics.
 
@@ -302,6 +303,7 @@ def _hook_on_fit_end(self, ctx):
         results = self.metric_calculator.eval(ctx)
         setattr(ctx, 'eval_metrics', results)
 
+
     def save_model(self, path, cur_round=-1):
         assert self.ctx.model is not None
 
diff --git a/federatedscope/core/worker/client.py b/federatedscope/core/worker/client.py
index 715df910c..ba6b2743d 100644
--- a/federatedscope/core/worker/client.py
+++ b/federatedscope/core/worker/client.py
@@ -444,6 +444,15 @@ def callback_funcs_for_finish(self, message: Message):
             self.trainer.update(message.content,
                                 strict=self._cfg.federate.share_local_model)
 
+        # TODO: more elegant here
+        # Save final prediction result
+        if self._cfg.data.type == 'cikmcup':
+            # Evaluate
+            self.trainer.evaluate(target_data_split_name='test')
+            # Save results
+            self.trainer.save_prediction(self.ID, self._cfg.model.task)
+            logger.info(f"Client #{self.ID} finished saving prediction results.")
+
         self._monitor.finish_fl()
 
     def callback_funcs_for_converged(self, message: Message):
diff --git a/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml b/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml
index 957684f3d..f49950054 100644
--- a/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml
+++ b/federatedscope/gfl/baseline/fedavg_gin_minibatch_on_cikmcup.yaml
@@ -9,6 +9,7 @@ federate:
   make_global_eval: False
   total_round_num: 100
   share_local_model: False
+  client_num: 13
 data:
   root: data/
   type: cikmcup
diff --git a/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml b/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml
index 7880f4f3b..f4db08333 100644
--- a/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml
+++ b/federatedscope/gfl/baseline/isolated_gin_minibatch_on_cikmcup.yaml
@@ -10,6 +10,7 @@ federate:
   make_global_eval: False
   total_round_num: 10
   share_local_model: False
+  client_num: 13
 data:
   batch_size: 64
   root: data/
diff --git a/federatedscope/gfl/dataset/cikm_cup.py b/federatedscope/gfl/dataset/cikm_cup.py
index 4240b2a07..3dfeb730d 100644
--- a/federatedscope/gfl/dataset/cikm_cup.py
+++ b/federatedscope/gfl/dataset/cikm_cup.py
@@ -7,7 +7,7 @@
 logger = logging.getLogger(__name__)
 
 class CIKMCUPDataset(InMemoryDataset):
-    name = 'CIKM_CUP'
+    name = 'CIKM22Competition'
 
     def __init__(self, root):
         super(CIKMCUPDataset, self).__init__(root)
diff --git a/federatedscope/gfl/trainer/graphtrainer.py b/federatedscope/gfl/trainer/graphtrainer.py
index 4b4f547ad..07c746a74 100644
--- a/federatedscope/gfl/trainer/graphtrainer.py
+++ b/federatedscope/gfl/trainer/graphtrainer.py
@@ -1,4 +1,7 @@
 import logging
+import os
+
+import numpy as np
 
 from federatedscope.core.monitors import Monitor
 from federatedscope.register import register_trainer
@@ -8,6 +11,10 @@
 
 
 class GraphMiniBatchTrainer(GeneralTorchTrainer):
+    def _hook_on_fit_start_init(self, ctx):
+        super()._hook_on_fit_start_init(ctx)
+        setattr(ctx, "{}_y_inds".format(ctx.cur_data_split), [])
+
     def _hook_on_batch_forward(self, ctx):
         batch = ctx.data_batch.to(ctx.device)
         pred = ctx.model(batch)
@@ -24,6 +31,14 @@ def _hook_on_batch_forward(self, ctx):
         ctx.y_true = label
         ctx.y_prob = pred
 
+        # record the index of the ${MODE} samples
+        if hasattr(ctx.data_batch, 'data_index'):
+            setattr(
+                ctx,
+                f'{ctx.cur_data_split}_y_inds',
+                ctx.get(f'{ctx.cur_data_split}_y_inds') + ctx.data_batch.data_index.detach().cpu().numpy().tolist()
+            )
+
     def _hook_on_batch_forward_flop_count(self, ctx):
         if not isinstance(self.ctx.monitor, Monitor):
             logger.warning(
@@ -69,6 +84,21 @@ def _hook_on_batch_forward_flop_count(self, ctx):
         self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
             ctx.batch_size
 
+    def save_prediction(self, client_id, task_type):
+        y_inds, y_probs = self.ctx.test_y_inds, self.ctx.test_y_prob
+        os.makedirs('prediction', exist_ok=True)
+
+        # TODO: more feasible, for now we hard code it for cikmcup
+        y_preds = np.argmax(y_probs, axis=-1) if 'classification' in \
+            task_type.lower() else y_probs
+
+        with open('prediction/prediction.csv', 'a') as file:
+            for y_ind, y_pred in zip(y_inds, y_preds):
+                if 'classification' in task_type.lower():
+                    line = [client_id, y_ind] + [y_pred]
+                else:
+                    line = [client_id, y_ind] + list(y_pred)
+                file.write(','.join([str(_) for _ in line]) + '\n')
+
 
 def call_graph_level_trainer(trainer_type):
     if trainer_type == 'graphminibatch_trainer':
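
Note (not part of the patch): a minimal sketch of how the rows that save_prediction() appends to prediction/prediction.csv could be read back for inspection. The helper name load_predictions is hypothetical; the per-row layout of client_id, sample index, then one value (classification) or several values (regression) follows the write logic in the hunk above.

import csv

def load_predictions(path='prediction/prediction.csv'):
    """Hypothetical helper: map (client_id, sample_index) to prediction values."""
    predictions = {}
    with open(path) as f:
        for row in csv.reader(f):
            client_id, sample_index = int(row[0]), int(row[1])
            # Remaining columns hold a single class id or the regression outputs.
            predictions[(client_id, sample_index)] = [float(v) for v in row[2:]]
    return predictions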