Saving prediction results on cikm contest data #238

Merged
merged 9 commits on Jul 18, 2022
Changes from 3 commits
2 changes: 2 additions & 0 deletions federatedscope/core/trainers/torch_trainer.py
@@ -289,6 +289,7 @@ def _hook_on_batch_end(self, ctx):
ctx.y_true = None
ctx.y_prob = None


def _hook_on_fit_end(self, ctx):
"""Evaluate metrics.

@@ -302,6 +303,7 @@ def _hook_on_fit_end(self, ctx):
results = self.metric_calculator.eval(ctx)
setattr(ctx, 'eval_metrics', results)


def save_model(self, path, cur_round=-1):
assert self.ctx.model is not None

8 changes: 8 additions & 0 deletions federatedscope/core/worker/client.py
@@ -444,6 +444,14 @@ def callback_funcs_for_finish(self, message: Message):
self.trainer.update(message.content,
strict=self._cfg.federate.share_local_model)

# TODO: handle this more elegantly
# Save final prediction result
if self._cfg.data.type == 'cikmcup':
# Evaluate
self.trainer.evaluate(target_data_split_name='test')
# Save results
self.trainer.save_prediction(self.ID, self._cfg.model.task)

self._monitor.finish_fl()

def callback_funcs_for_converged(self, message: Message):
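Note on the hunk above: the ordering matters, because save_prediction reads the test-split indices and predictions that the evaluation pass leaves on the trainer context. A minimal standalone sketch of that dependency, using stand-in names (DummyTrainer and hard-coded fields) rather than the real FederatedScope API:

```python
# Minimal sketch, NOT the FederatedScope API: evaluate() must populate the
# test-split fields before save_prediction() can write them out.
class DummyTrainer:
    def __init__(self):
        self.test_y_inds = []   # stand-in for ctx.test_y_inds
        self.test_y_prob = []   # stand-in for ctx.test_y_prob

    def evaluate(self, target_data_split_name='test'):
        # pretend the model was run over the test split
        self.test_y_inds = [0, 1, 2]
        self.test_y_prob = [[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]]

    def save_prediction(self, client_id, task_type):
        for ind, prob in zip(self.test_y_inds, self.test_y_prob):
            print(client_id, ind, task_type, prob)


trainer = DummyTrainer()
trainer.evaluate(target_data_split_name='test')   # fills the test_* fields first
trainer.save_prediction(client_id=1, task_type='classification')
```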
24 changes: 24 additions & 0 deletions federatedscope/gfl/trainer/graphtrainer.py
@@ -1,5 +1,7 @@
import logging

import numpy as np

from federatedscope.core.monitors import Monitor
from federatedscope.register import register_trainer
from federatedscope.core.trainers import GeneralTorchTrainer
@@ -8,6 +10,10 @@


class GraphMiniBatchTrainer(GeneralTorchTrainer):
def _hook_on_fit_start_init(self, ctx):
super()._hook_on_fit_start_init(ctx)
setattr(ctx, "{}_y_inds".format(ctx.cur_data_split), [])

def _hook_on_batch_forward(self, ctx):
batch = ctx.data_batch.to(ctx.device)
pred = ctx.model(batch)
@@ -20,6 +26,14 @@ def _hook_on_batch_forward(self, ctx):
ctx.y_true = label
ctx.y_prob = pred

# record the dataset indices of the samples in the current data split
if hasattr(batch, 'data_index'):
setattr(
ctx,
f'{ctx.cur_data_split}_y_inds',
ctx.get(f'{ctx.cur_data_split}_y_inds') + ctx.data_batch.data_index.cpu().numpy().tolist()
)
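Side note on the index bookkeeping added here: every mini-batch carries the dataset indices of its samples (batch.data_index), and appending them batch by batch keeps them aligned with the predictions gathered over the same batches. A small self-contained illustration of that pattern with made-up batch data (plain numpy, not FederatedScope code):

```python
import numpy as np

# Illustrative only: mimic how {split}_y_inds grows alongside per-batch predictions.
test_y_inds = []           # mirrors ctx.test_y_inds
test_y_prob_chunks = []    # per-batch prediction arrays, stacked at the end

batches = [
    {"data_index": np.array([3, 0]), "prob": np.array([[0.2, 0.8], [0.9, 0.1]])},
    {"data_index": np.array([2, 1]), "prob": np.array([[0.6, 0.4], [0.3, 0.7]])},
]
for batch in batches:
    # same idea as: ctx.{split}_y_inds += batch.data_index.cpu().numpy().tolist()
    test_y_inds += batch["data_index"].tolist()
    test_y_prob_chunks.append(batch["prob"])

test_y_prob = np.concatenate(test_y_prob_chunks)
assert len(test_y_inds) == len(test_y_prob)   # indices stay aligned with predictions
```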

def _hook_on_batch_forward_flop_count(self, ctx):
if not isinstance(self.ctx.monitor, Monitor):
logger.warning(
@@ -65,6 +79,16 @@ def _hook_on_batch_forward_flop_count(self, ctx):
self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
ctx.batch_size

def save_prediction(self, client_id, task_type):
y_inds, y_probs = self.ctx.test_y_inds, self.ctx.test_y_prob
if 'classification' in task_type:
y_probs = np.argmax(y_probs, axis=-1)
Collaborator:
y_preds = np.argmax(y_probs, axis=-1)

Collaborator (Author):
Modified accordingly

# TODO: make this more flexible; for now it is hard-coded for cikmcup
with open('prediction/round_{}.csv', 'a') as file:
for y_ind, y_prob in zip(y_inds, y_probs):
line = [client_id, y_ind] + list(y_prob)
file.write(','.join(str(v) for v in line) + '\n')
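For completeness, a hedged sketch of reading back a file written by the loop above. The round placeholder in the filename is left unformatted at this commit, so 'prediction/round_0.csv' below is only an assumed example path; each row is client_id, sample index, then one or more prediction values (a single class index for classification, one value per target for regression):

```python
import csv

# Assumed example path; the actual filename depends on how the round
# placeholder above ends up being formatted.
with open('prediction/round_0.csv') as f:
    rows = []
    for record in csv.reader(f):
        client_id, sample_idx = int(record[0]), int(record[1])
        values = [float(v) for v in record[2:]]   # prediction column(s)
        rows.append((client_id, sample_idx, values))

print(rows[:3])
```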


def call_graph_level_trainer(trainer_type):
if trainer_type == 'graphminibatch_trainer':