From db54effb7bbb782218e4f300d3063a3e30064031 Mon Sep 17 00:00:00 2001 From: futabato <01futabato10@gmail.com> Date: Fri, 26 Jul 2024 17:59:13 +0900 Subject: [PATCH] :white_check_mark: Add krum test --- tests/test_aggregator.py | 59 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tests/test_aggregator.py b/tests/test_aggregator.py index 6bbb9df..9c7ee30 100644 --- a/tests/test_aggregator.py +++ b/tests/test_aggregator.py @@ -3,6 +3,7 @@ import torch from federatedlearning.server.aggregations.aggregators import ( average_weights, + krum, median_weights, ) @@ -87,5 +88,63 @@ def test_median_weights(self) -> None: ) +class TestKrumAlgorithm(unittest.TestCase): + def test_krum_basic(self) -> None: + weights = [ + { + "a": torch.tensor([1.0, 2.0]), + "b": torch.tensor([3.0, 4.0]), + }, # Benign + { + "a": torch.tensor([1.1, 2.1]), + "b": torch.tensor([3.1, 4.1]), + }, # Benign + { + "a": torch.tensor([10.0, 20.0]), + "b": torch.tensor([30.0, 40.0]), + }, # Malicious + ] + f = 1 + # Weights to be considered most normal. + expected_result: dict[str, torch.Tensor] = { + "a": torch.tensor([1.0, 2.0]), + "b": torch.tensor([3.0, 4.0]), + } + result: dict[str, torch.Tensor] = krum(weights, f) + for key in expected_result.keys(): + self.assertTrue( + torch.equal(expected_result[key], result[key]), + f"Weights median incorrectly for key '{key}'.", + ) + + def test_not_enough_weights(self) -> None: + weights = [ + {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])} + ] + f = 1 + + with self.assertRaises(ValueError): + krum(weights, f) + + def test_all_same_weights(self) -> None: + weights = [ + {"a": torch.tensor([1.0, 1.0]), "b": torch.tensor([1.0, 1.0])}, + {"a": torch.tensor([1.0, 1.0]), "b": torch.tensor([1.0, 1.0])}, + {"a": torch.tensor([1.0, 1.0]), "b": torch.tensor([1.0, 1.0])}, + ] + f = 1 + # All weights are equal, so any of them can be chosen + expected_result: dict[str, torch.Tensor] = { + "a": torch.tensor([1.0, 1.0]), + "b": torch.tensor([1.0, 1.0]), + } + result: dict[str, torch.Tensor] = krum(weights, f) + for key in expected_result.keys(): + self.assertTrue( + torch.equal(expected_result[key], result[key]), + f"Weights median incorrectly for key '{key}'.", + ) + + if __name__ == "__main__": unittest.main()