From d5b4039f3df72c643b45fdc4018a7821f4741348 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 21:00:33 +0000 Subject: [PATCH 01/12] show interface type and name, fix propagation of topology --- exo/networking/udp/udp_discovery.py | 6 ++++-- exo/orchestration/standard_node.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/exo/networking/udp/udp_discovery.py b/exo/networking/udp/udp_discovery.py index 2dd9235fc..624933d28 100644 --- a/exo/networking/udp/udp_discovery.py +++ b/exo/networking/udp/udp_discovery.py @@ -88,7 +88,7 @@ async def task_broadcast_presence(self): # Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over # the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet for addr, interface_name in get_all_ip_addresses_and_interfaces(): - interface_priority, _ = get_interface_priority_and_type(interface_name) + interface_priority, interface_type = get_interface_priority_and_type(interface_name) message = json.dumps({ "type": "discovery", "node_id": self.node_id, @@ -96,6 +96,7 @@ async def task_broadcast_presence(self): "device_capabilities": self.device_capabilities.to_dict(), "priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi. "interface_name": interface_name, + "interface_type": interface_type, }) if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}") @@ -145,6 +146,7 @@ async def on_listen_message(self, data, addr): peer_port = message["grpc_port"] peer_prio = message["priority"] peer_interface_name = message["interface_name"] + peer_interface_type = message["interface_type"] device_capabilities = DeviceCapabilities(**message["device_capabilities"]) if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}": @@ -154,7 +156,7 @@ async def on_listen_message(self, data, addr): if DEBUG >= 1: print(f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}") return - new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", peer_interface_name, device_capabilities) + new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", f"{peer_interface_type} ({peer_interface_name})", device_capabilities) if not await new_peer_handle.health_check(): if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.") return diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index f045c4efe..1deb7c044 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -410,16 +410,16 @@ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) try: other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0) if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}") - self.topology.merge(other_topology) + next_topology.merge(other_topology) except Exception as e: print(f"Error collecting topology from {peer.id()}: {e}") traceback.print_exc() - next_topology.active_node_id = self.topology.active_node_id # this is not so clean. + next_topology.active_node_id = self.topology.active_node_id self.topology = next_topology if self.topology_viz: - self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id) - return next_topology + self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id) + return self.topology @property def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]: From 99b5bf01a24a2b5960224e9c86c215542f309f86 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 21:12:47 +0000 Subject: [PATCH 02/12] fix topology merging --- exo/topology/topology.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/exo/topology/topology.py b/exo/topology/topology.py index 2ec5a2aba..5549a9cbb 100644 --- a/exo/topology/topology.py +++ b/exo/topology/topology.py @@ -34,18 +34,11 @@ def get_node(self, node_id: str) -> DeviceCapabilities: def all_nodes(self): return self.nodes.items() - def add_edge(self, node1_id: str, node2_id: str, description: Optional[str] = None): - if node1_id not in self.peer_graph: - self.peer_graph[node1_id] = set() - if node2_id not in self.peer_graph: - self.peer_graph[node2_id] = set() - - # Create bidirectional connections with the same description - conn1 = PeerConnection(node1_id, node2_id, description) - conn2 = PeerConnection(node2_id, node1_id, description) - - self.peer_graph[node1_id].add(conn1) - self.peer_graph[node2_id].add(conn2) + def add_edge(self, from_id: str, to_id: str, description: Optional[str] = None): + if from_id not in self.peer_graph: + self.peer_graph[from_id] = set() + conn = PeerConnection(from_id, to_id, description) + self.peer_graph[from_id].add(conn) def get_neighbors(self, node_id: str) -> Set[str]: # Convert PeerConnection objects back to just destination IDs From dba720445387961d8a9fac0673600089e04d1261 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 21:38:03 +0000 Subject: [PATCH 03/12] handle mutable visited properly --- exo/orchestration/standard_node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 1deb7c044..310f18f60 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -56,7 +56,7 @@ async def start(self, wait_for_peers: int = 0) -> None: await self.server.start() await self.discovery.start() await self.update_peers(wait_for_peers) - await self.collect_topology() + await self.collect_topology(set()) if DEBUG >= 2: print(f"Collected topology: {self.topology}") asyncio.create_task(self.periodic_topology_collection(1.0)) @@ -374,8 +374,8 @@ async def periodic_topology_collection(self, interval: int): try: did_peers_change = await self.update_peers() if DEBUG >= 2: print(f"{did_peers_change=}") + await self.collect_topology(set()) if did_peers_change: - await self.collect_topology() await self.select_best_inference_engine() except Exception as e: print(f"Error collecting topology: {e}") @@ -386,7 +386,7 @@ async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarr return None, False return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1] - async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology: + async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology: next_topology = Topology() next_topology.update_node(self.id, self.device_capabilities) From f8d195eea580c8bfec0b64c179cbadb7727706da Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 21:42:41 +0000 Subject: [PATCH 04/12] only collect topology when peers changed --- exo/orchestration/standard_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 310f18f60..777947538 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -374,8 +374,8 @@ async def periodic_topology_collection(self, interval: int): try: did_peers_change = await self.update_peers() if DEBUG >= 2: print(f"{did_peers_change=}") - await self.collect_topology(set()) if did_peers_change: + await self.collect_topology(set()) await self.select_best_inference_engine() except Exception as e: print(f"Error collecting topology: {e}") From 68d70be900c909e1eca48c7c437ecdc5efaaac28 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 21:55:31 +0000 Subject: [PATCH 05/12] always show desc1/desc2 in tui --- exo/viz/topology_viz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/viz/topology_viz.py b/exo/viz/topology_viz.py index 1c121dea6..4487cb221 100644 --- a/exo/viz/topology_viz.py +++ b/exo/viz/topology_viz.py @@ -253,7 +253,7 @@ def _generate_main_layout(self) -> str: conn2 = self.topology.peer_graph.get(self.partitions[next_i].node_id, set()) description1 = next((c.description for c in conn1 if c.to_id == self.partitions[next_i].node_id), "") description2 = next((c.description for c in conn2 if c.to_id == partition.node_id), "") - connection_description = f"{description1}/{description2}" if description1 != description2 else description1 + connection_description = f"{description1}/{description2}" # Simple line drawing steps = max(abs(next_x - x), abs(next_y - y)) From 272b1e2a1ed9c52ff104c68ffd5a0246d0920a16 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 21:55:40 +0000 Subject: [PATCH 06/12] remove unused funcs --- exo/topology/topology.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/exo/topology/topology.py b/exo/topology/topology.py index 5549a9cbb..b966f7f14 100644 --- a/exo/topology/topology.py +++ b/exo/topology/topology.py @@ -1,5 +1,5 @@ from .device_capabilities import DeviceCapabilities -from typing import Dict, Set, Optional, NamedTuple +from typing import Dict, Set, Optional from dataclasses import dataclass @dataclass @@ -40,19 +40,6 @@ def add_edge(self, from_id: str, to_id: str, description: Optional[str] = None): conn = PeerConnection(from_id, to_id, description) self.peer_graph[from_id].add(conn) - def get_neighbors(self, node_id: str) -> Set[str]: - # Convert PeerConnection objects back to just destination IDs - return {conn.to_id for conn in self.peer_graph.get(node_id, set())} - - def all_edges(self): - edges = [] - for node_id, connections in self.peer_graph.items(): - for conn in connections: - # Only include each edge once by checking if reverse already exists - if not any(e[0] == conn.to_id and e[1] == conn.from_id for e in edges): - edges.append((conn.from_id, conn.to_id, conn.description)) - return edges - def merge(self, other: "Topology"): for node_id, capabilities in other.nodes.items(): self.update_node(node_id, capabilities) From 81632247653ab2799a4b161ccc73d5283dc51b7a Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 21:59:06 +0000 Subject: [PATCH 07/12] coll --- exo/orchestration/standard_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 777947538..310f18f60 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -374,8 +374,8 @@ async def periodic_topology_collection(self, interval: int): try: did_peers_change = await self.update_peers() if DEBUG >= 2: print(f"{did_peers_change=}") + await self.collect_topology(set()) if did_peers_change: - await self.collect_topology(set()) await self.select_best_inference_engine() except Exception as e: print(f"Error collecting topology: {e}") From 657520ed4c9abfedcf8a4e4dec80e567c190ee7a Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 22:31:21 +0000 Subject: [PATCH 08/12] pass origin_node_id to merge --- exo/orchestration/standard_node.py | 2 +- exo/topology/topology.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 310f18f60..b88b17223 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -410,7 +410,7 @@ async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topol try: other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0) if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}") - next_topology.merge(other_topology) + next_topology.merge(self.id, other_topology) except Exception as e: print(f"Error collecting topology from {peer.id()}: {e}") traceback.print_exc() diff --git a/exo/topology/topology.py b/exo/topology/topology.py index b966f7f14..6040369fa 100644 --- a/exo/topology/topology.py +++ b/exo/topology/topology.py @@ -21,7 +21,6 @@ def __eq__(self, other): class Topology: def __init__(self): self.nodes: Dict[str, DeviceCapabilities] = {} - # Store PeerConnection objects in the adjacency lists self.peer_graph: Dict[str, Set[PeerConnection]] = {} self.active_node_id: Optional[str] = None @@ -40,12 +39,14 @@ def add_edge(self, from_id: str, to_id: str, description: Optional[str] = None): conn = PeerConnection(from_id, to_id, description) self.peer_graph[from_id].add(conn) - def merge(self, other: "Topology"): + def merge(self, origin_node_id: str, other: "Topology"): for node_id, capabilities in other.nodes.items(): - self.update_node(node_id, capabilities) + if node_id != origin_node_id: + self.update_node(node_id, capabilities) for node_id, connections in other.peer_graph.items(): for conn in connections: - self.add_edge(conn.from_id, conn.to_id, conn.description) + if conn.from_id != origin_node_id: + self.add_edge(conn.from_id, conn.to_id, conn.description) def __str__(self): nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items()) From 0b9ee8abf70631737d141bca983df9c97e3162d1 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 22:33:42 +0000 Subject: [PATCH 09/12] consistnet self.topology --- exo/orchestration/standard_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index b88b17223..5ba3b632e 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -83,7 +83,7 @@ def on_node_status(self, request_id, opaque_status): download_progress = RepoProgressEvent.from_dict(status_data.get('progress')) self.node_download_progress[status_data.get('node_id')] = download_progress if self.topology_viz: - self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress) + self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress) except Exception as e: if DEBUG >= 1: print(f"Error updating visualization: {e}") if DEBUG >= 1: traceback.print_exc() From a0e083a14118c38b9a1cb655498d0a54057eed5d Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 22:51:52 +0000 Subject: [PATCH 10/12] test --- exo/networking/grpc/grpc_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index d696eca90..d4a90d602 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -85,7 +85,7 @@ async def GetInferenceResult(self, request, context): async def CollectTopology(self, request, context): max_depth = request.max_depth visited = set(request.visited) - topology = await self.node.collect_topology(visited, max_depth) + topology = self.node.current_topology nodes = { node_id: node_service_pb2.DeviceCapabilities( From 55344241aa192fd42338c10337afa670d51ebfdd Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 22:57:31 +0000 Subject: [PATCH 11/12] use double for flops protobuf --- exo/networking/grpc/node_service.proto | 6 +++--- exo/networking/grpc/node_service_pb2.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/exo/networking/grpc/node_service.proto b/exo/networking/grpc/node_service.proto index 8739faa21..5ea8b11ae 100644 --- a/exo/networking/grpc/node_service.proto +++ b/exo/networking/grpc/node_service.proto @@ -66,9 +66,9 @@ message PeerConnections { } message DeviceFlops { - float fp32 = 1; - float fp16 = 2; - float int8 = 3; + double fp32 = 1; + double fp16 = 2; + double int8 = 3; } message DeviceCapabilities { diff --git a/exo/networking/grpc/node_service_pb2.py b/exo/networking/grpc/node_service_pb2.py index a7bf1a776..87246583b 100644 --- a/exo/networking/grpc/node_service_pb2.py +++ b/exo/networking/grpc/node_service_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\xb4\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\xb4\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) From db7c388ac61f7edeab57deef8681ea769332d31b Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 5 Dec 2024 23:03:18 +0000 Subject: [PATCH 12/12] remove origin_node_id --- exo/orchestration/standard_node.py | 2 +- exo/topology/topology.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index 5ba3b632e..dbb146a87 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -410,7 +410,7 @@ async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topol try: other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0) if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}") - next_topology.merge(self.id, other_topology) + next_topology.merge(other_topology) except Exception as e: print(f"Error collecting topology from {peer.id()}: {e}") traceback.print_exc() diff --git a/exo/topology/topology.py b/exo/topology/topology.py index 6040369fa..a002cdc11 100644 --- a/exo/topology/topology.py +++ b/exo/topology/topology.py @@ -39,14 +39,12 @@ def add_edge(self, from_id: str, to_id: str, description: Optional[str] = None): conn = PeerConnection(from_id, to_id, description) self.peer_graph[from_id].add(conn) - def merge(self, origin_node_id: str, other: "Topology"): + def merge(self, other: "Topology"): for node_id, capabilities in other.nodes.items(): - if node_id != origin_node_id: - self.update_node(node_id, capabilities) + self.update_node(node_id, capabilities) for node_id, connections in other.peer_graph.items(): for conn in connections: - if conn.from_id != origin_node_id: - self.add_edge(conn.from_id, conn.to_id, conn.description) + self.add_edge(conn.from_id, conn.to_id, conn.description) def __str__(self): nodes_str = ", ".join(f"{node_id}: {cap}" for node_id, cap in self.nodes.items())