diff --git a/src/federatedlearning/main.py b/src/federatedlearning/main.py index af3b918..643d19b 100644 --- a/src/federatedlearning/main.py +++ b/src/federatedlearning/main.py @@ -30,6 +30,7 @@ ) from federatedlearning.server.aggregations.aggregators import ( average_weights, + krum, median_weights, ) from federatedlearning.server.inferencing import inference @@ -283,6 +284,10 @@ def main(cfg: DictConfig) -> float: # noqa: C901 global_weights = average_weights(local_weights) elif cfg.federatedlearning.aggregation == "median": global_weights = median_weights(local_weights) + elif cfg.federatedlearning.aggregation == "krum": + global_weights = krum( + local_weights, cfg.federatedlearning.num_byzantines + ) else: global_weights = average_weights(local_weights) # Save updated global model weights