FederatedAggregator.py
# Ray
import ray
import ray.exceptions
# Utils
import collections
# NumPy and PyTorch
import numpy as np
import torch


def _is_numeric(x):
"""Return True if x is a numeric type (int, float, np.ndarray, torch.Tensor)."""
if isinstance(x, (int, float)):
return True
if isinstance(x, np.ndarray):
# Check if the dtype is numeric
return x.dtype.kind in ("i", "f") # int or float
if isinstance(x, torch.Tensor):
# Check if the dtype is numeric
return x.dtype in (torch.float16, torch.float32, torch.float64,
torch.int16, torch.int32, torch.int64)
    return False


def _to_tensor(x):
""" Convert x to a torch.Tensor. """
if isinstance(x, (int, float)):
return torch.tensor(x, dtype=torch.float32)
elif isinstance(x, np.ndarray):
return torch.from_numpy(x).float()
elif isinstance(x, torch.Tensor):
return x.float()
else:
        raise TypeError(f"[HEAD][WARN] Implicit conversion to tensor not supported for type {type(x)}.")


def _recursive_average(dict_list):
"""
From a list of dictionaries, compute the average of all numeric fields.
If a field is a dictionary, recursively compute the average of its fields.
"""
out = collections.OrderedDict()
all_keys = set()
for d in dict_list:
all_keys.update(d.keys())
for key in sorted(all_keys):
# Collect all values for this key
vals = [d[key] for d in dict_list if key in d]
        # If all vals are dictionaries -> recursively average
if all(isinstance(v, dict) for v in vals):
out[key] = _recursive_average(vals)
# If all vals are numeric -> convert to tensor and check shape
elif all(_is_numeric(v) for v in vals):
# Convert to tensor
ts = [_to_tensor(v) for v in vals]
shapes = [t.shape for t in ts]
if not all(s == shapes[0] for s in shapes):
# If shapes are different, skip
print(f"[HEAD][SKIP] Key {key}: different shapes.")
continue
# Calculate means
stacked = torch.stack(ts, dim=0)
mean_t = stacked.mean(dim=0)
out[key] = mean_t
        else:
            # Mixed or unsupported value types for this key: skip it
            print(f"[HEAD][SKIP] Key {key}: values are neither all numeric nor all dictionaries.")
return out
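

# Illustrative example (not executed) of what _recursive_average produces for
# two nested policy dicts with matching tensor shapes; printed values are
# approximate:
#   a = {"w": torch.ones(2),  "stats": {"lr": 0.01}}
#   b = {"w": torch.zeros(2), "stats": {"lr": 0.03}}
#   _recursive_average([a, b])
#   -> OrderedDict([("stats", OrderedDict([("lr", tensor(0.0200))])),
#                   ("w", tensor([0.5000, 0.5000]))])
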
@ray.remote
class FederatedAggregator:
    def __init__(self, nodes, timeout, EXCEPTIONS=None):
        self.nodes = set(nodes)
        self.failed_nodes = set()
        self.timeout = timeout
        # Exception types treated as a node failure; default to Ray's base error
        # so that "except self.EXCEPTIONS" never receives None
        self.EXCEPTIONS = EXCEPTIONS if EXCEPTIONS is not None else (ray.exceptions.RayError,)

    def federated_averaging(self):
        """
        Collect the weights from each node, average them, and return the result
        as the new global weights.
        In RLlib PPO, the actual policy weights are stored in weights["default_policy"].
        """
        weights_list = []
        successful_nodes = set()
        # Launch an asynchronous get_weights() request to every registered node
        aggr_tasks = {}
for (node_id, node_handle) in list(self.nodes):
task = node_handle.get_weights.remote()
aggr_tasks[task] = (node_id, node_handle)
        # Wait up to self.timeout for the weight requests; tasks that miss the
        # deadline are simply skipped for this aggregation round
        done, _ = ray.wait(
            list(aggr_tasks.keys()),
            timeout=self.timeout,
            num_returns=len(aggr_tasks)
        )
# Collect the results
for task in done:
node_id, node_handle = aggr_tasks[task]
try:
weights = ray.get(task)
weights_list.append(weights)
successful_nodes.add((node_id, node_handle))
except self.EXCEPTIONS as e:
print(f"[HEAD][WARN] Node {node_id} failed to return weights. Error: {e}")
self.nodes.remove((node_id, node_handle))
self.failed_nodes.add((node_id, node_handle))
print(f"[HEAD][INFO] Weights collected from {len(successful_nodes)} nodes.")
if not weights_list:
print("[HEAD][WARN] No weights collected. Skipping aggregation.")
return None
# Average the weights
default_policy_list = []
for w in weights_list:
# Check if the node has a "default_policy" field
if "default_policy" in w:
default_policy_list.append(w["default_policy"])
            else:
                print("[HEAD][WARN] Node weights contain no 'default_policy'; skipping.")
        if not default_policy_list:
            # No "default_policy" found in any node: nothing to aggregate
            print("[HEAD][WARN] No default_policy found in any node.")
            return None
# Recursive average
avg_default_policy = _recursive_average(default_policy_list)
# Return the new global weights (only the "default_policy" field)
global_weights = collections.OrderedDict({
"default_policy": avg_default_policy
})
print("[HAED][INFO] FedAvg: Global weights updated.")
return global_weights
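

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes each federated node is a Ray
# actor exposing get_weights() -- as federated_averaging() expects -- plus a
# set_weights() method for receiving the averaged policy. MockNode below is a
# hypothetical stand-in, not an RLlib trainer.
# ----------------------------------------------------------------------------
@ray.remote
class MockNode:
    """Minimal stand-in for a federated trainer node (hypothetical)."""

    def __init__(self, seed):
        torch.manual_seed(seed)
        self._weights = {
            "default_policy": collections.OrderedDict({
                "layer.weight": torch.randn(4, 4),
                "layer.bias": torch.randn(4),
            })
        }

    def get_weights(self):
        return self._weights

    def set_weights(self, weights):
        self._weights = weights


if __name__ == "__main__":
    ray.init()
    nodes = [(i, MockNode.remote(seed=i)) for i in range(3)]
    aggregator = FederatedAggregator.remote(
        nodes=nodes,
        timeout=30.0,                           # seconds to wait for replies
        EXCEPTIONS=(ray.exceptions.RayError,),  # failures tolerated per node
    )
    global_weights = ray.get(aggregator.federated_averaging.remote())
    if global_weights is not None:
        # Push the averaged policy back to every surviving node
        ray.get([handle.set_weights.remote(global_weights) for _, handle in nodes])
    ray.shutdown()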