Skip to content

Commit

Permalink
✅ Add krum test
Browse files Browse the repository at this point in the history
  • Loading branch information
futabato committed Jul 26, 2024
1 parent ebdb2ed commit db54eff
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from federatedlearning.server.aggregations.aggregators import (
average_weights,
krum,
median_weights,
)

Expand Down Expand Up @@ -87,5 +88,63 @@ def test_median_weights(self) -> None:
)


class TestKrumAlgorithm(unittest.TestCase):
def test_krum_basic(self) -> None:
weights = [
{
"a": torch.tensor([1.0, 2.0]),
"b": torch.tensor([3.0, 4.0]),
}, # Benign
{
"a": torch.tensor([1.1, 2.1]),
"b": torch.tensor([3.1, 4.1]),
}, # Benign
{
"a": torch.tensor([10.0, 20.0]),
"b": torch.tensor([30.0, 40.0]),
}, # Malicious
]
f = 1
# Weights to be considered most normal.
expected_result: dict[str, torch.Tensor] = {
"a": torch.tensor([1.0, 2.0]),
"b": torch.tensor([3.0, 4.0]),
}
result: dict[str, torch.Tensor] = krum(weights, f)
for key in expected_result.keys():
self.assertTrue(
torch.equal(expected_result[key], result[key]),
f"Weights median incorrectly for key '{key}'.",
)

def test_not_enough_weights(self) -> None:
weights = [
{"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}
]
f = 1

with self.assertRaises(ValueError):
krum(weights, f)

def test_all_same_weights(self) -> None:
weights = [
{"a": torch.tensor([1.0, 1.0]), "b": torch.tensor([1.0, 1.0])},
{"a": torch.tensor([1.0, 1.0]), "b": torch.tensor([1.0, 1.0])},
{"a": torch.tensor([1.0, 1.0]), "b": torch.tensor([1.0, 1.0])},
]
f = 1
# All weights are equal, so any of them can be chosen
expected_result: dict[str, torch.Tensor] = {
"a": torch.tensor([1.0, 1.0]),
"b": torch.tensor([1.0, 1.0]),
}
result: dict[str, torch.Tensor] = krum(weights, f)
for key in expected_result.keys():
self.assertTrue(
torch.equal(expected_result[key], result[key]),
f"Weights median incorrectly for key '{key}'.",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit db54eff

Please # to comment.