diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py index 060f8830a..4950005a6 100644 --- a/alpa/device_mesh.py +++ b/alpa/device_mesh.py @@ -599,10 +599,14 @@ def get_live_buffer_uuids(self): return list(self.buffers.keys()) ##### Other Functions ##### - def sync(self): + def sync(self, sync_all_devices=False): # We sync one device instead of all for smaller runtime overhead. # This is correct because of SPMD. - self.local_devices[0].synchronize_all_activity() + if sync_all_devices: + for device in self.local_devices: + device.synchronize_all_activity() + else: + self.local_devices[0].synchronize_all_activity() def sync_all(self): for device in self.local_devices: @@ -1396,8 +1400,8 @@ def reset_memory_stats(self): ray.get(worker.reset_memory_stats.remote()) ##### Other Functions ##### - def sync_workers(self): - ray.get([w.sync.remote() for w in self.workers]) + def sync_workers(self, sync_all_devices=False): + ray.get([w.sync.remote(sync_all_devices) for w in self.workers]) def sync_move_workers(self): ray.get([w.sync_move_worker.remote() for w in self.workers]) diff --git a/alpa/global_env.py b/alpa/global_env.py index be58ea545..50c961785 100644 --- a/alpa/global_env.py +++ b/alpa/global_env.py @@ -65,6 +65,11 @@ def __init__(self): # "xla_extension"} self.nccl_mode = "cupy" self.enable_overlapping = False + # Cross mesh resharding load balancing mode. + # Possible choices: {"normal", "no_loadbalance", + # "loadbalance_size", "loadbalance_order"} + self.resharding_loadbalance_mode = "normal" + self.loadbalance_order_algo = "greedy" ########## Options of benchmark ########## # If true, the system is allowed to use dummy values during diff --git a/alpa/pipeline_parallel/cross_mesh_resharding.py b/alpa/pipeline_parallel/cross_mesh_resharding.py index 65c551ffc..388ba8847 100644 --- a/alpa/pipeline_parallel/cross_mesh_resharding.py +++ b/alpa/pipeline_parallel/cross_mesh_resharding.py @@ -1,17 +1,21 @@ """Cross mesh resharding for pipeline parallelism.""" +from abc import ABC, abstractmethod +from collections import namedtuple import logging import math +import random +import time from typing import List, Any -import numpy as np from jax.interpreters import pxla +import numpy as np import ray import alpa.collective as col from alpa.device_mesh import (DistributedArray, RemoteArrayRef, ReshardingRecvSpec, ReshardingSendSpec, ReshardingTileSpec, ReshardingBroadcastSpec, - _device_mesh_put_dummy) + _device_mesh_put_dummy, device_id_to_str) from alpa.global_env import global_config from alpa.mesh_executable import (UtilMeshWorkerExecutable, next_mesh_executable_uuid) @@ -289,48 +293,53 @@ def create_resharding_communicators(self): def _compile_send_recv_tasks(self): """Generate all send/recv tasks.""" - for i, (dst_tile, src_tiles, indices_in_dst_tiles) in enumerate( - self.task_spec.dst_tile_to_src_tiles_map): + dtype = self.task_spec.src.aval.dtype + + # print("order: ", self.task_spec.strategy.order) + for i, k, j in self.task_spec.strategy.order: spec_plan = self.task_spec.strategy.per_spec_plans[i] - for replica_index, receiver in enumerate( - dst_tile.replica_device_strs): - # Get args for an empty buffer - receiver_device_id = ( - self.collective_group.device_str_to_device_id_map[receiver]) - receiver_worker = (self.collective_group. 
- device_str_to_mesh_worker_map[receiver]) - dtype = self.task_spec.src.aval.dtype - # Get args for send/recv - senders = [ - spec_plan[replica_index][src_tile_index] - for src_tile_index, _ in enumerate(src_tiles) - ] - receiver_rank, receiver_gpu_idx = ( - self.collective_group.device_str_to_rank_map[receiver]) - recv_tile_specs = [] - for sender_idx, sender in enumerate(senders): - # Sender's task - sender_worker = (self.collective_group. - device_str_to_mesh_worker_map[sender]) - src_device_id = (self.collective_group. - device_str_to_device_id_map[sender]) - self._sender_tasks[sender_worker].append( - ReshardingSendSpec( - src_device_id, - ReshardingTileSpec(src_tiles[sender_idx].offset, - receiver_rank, - receiver_gpu_idx))) - # Receiver's task - sender_rank, sender_gpu_idx = \ - self.collective_group.device_str_to_rank_map[sender] - indices_in_dst_tile = indices_in_dst_tiles[sender_idx] - recv_tile_specs.append( - ReshardingTileSpec(indices_in_dst_tile, sender_rank, - sender_gpu_idx)) - receiver_task = ReshardingRecvSpec(receiver_device_id, - dst_tile.tile_shape, dtype, - recv_tile_specs) - self._receiver_tasks[receiver_worker].append(receiver_task) + dst_tile, src_tiles, indices_in_dst_tiles = ( + self.task_spec.dst_tile_to_src_tiles_map[i]) + replica_index, receiver = k, dst_tile.replica_device_strs[k] + _, _, indices_in_dst_tile = (j, src_tiles[j], + indices_in_dst_tiles[j]) + + # Get args for an empty buffer + receiver_device_id = ( + self.collective_group.device_str_to_device_id_map[receiver]) + receiver_worker = ( + self.collective_group.device_str_to_mesh_worker_map[receiver]) + dtype = self.task_spec.src.aval.dtype + # Get args for send/recv + senders = [ + spec_plan[replica_index][src_tile_index] + for src_tile_index, _ in enumerate(src_tiles) + ] + receiver_rank, receiver_gpu_idx = ( + self.collective_group.device_str_to_rank_map[receiver]) + recv_tile_specs = [] + for sender_idx, sender in enumerate(senders): + # Sender's task + sender_worker = ( + self.collective_group.device_str_to_mesh_worker_map[sender]) + src_device_id = ( + self.collective_group.device_str_to_device_id_map[sender]) + self._sender_tasks[sender_worker].append( + ReshardingSendSpec( + src_device_id, + ReshardingTileSpec(src_tiles[sender_idx].offset, + receiver_rank, receiver_gpu_idx))) + # Receiver's task + sender_rank, sender_gpu_idx = \ + self.collective_group.device_str_to_rank_map[sender] + indices_in_dst_tile = indices_in_dst_tiles[sender_idx] + recv_tile_specs.append( + ReshardingTileSpec(indices_in_dst_tile, sender_rank, + sender_gpu_idx)) + receiver_task = ReshardingRecvSpec(receiver_device_id, + dst_tile.tile_shape, dtype, + recv_tile_specs) + self._receiver_tasks[receiver_worker].append(receiver_task) # FIXME(Hao): test the function below; it might be buggy. 
def do_prepared(self, src_array, profiling=False): @@ -457,64 +466,65 @@ def put_all_tasks(self): def _compile_broadcast_tasks(self): """Compile broadcast tasks.""" dtype = self.task_spec.src.aval.dtype - for i, (dst_tile, src_tiles, indices_in_dst_tiles) in enumerate( - self.task_spec.dst_tile_to_src_tiles_map): - spec_plan = self.task_spec.strategy.per_spec_plans[i] - for src_tile_index, (src_tile, indices_in_dst_tile) in enumerate( - zip(src_tiles, indices_in_dst_tiles)): - sender = spec_plan[src_tile_index] - sender_worker = ( - self.collective_group.device_str_to_mesh_worker_map[sender]) - broadcast_group = (i, src_tile_index) - devices = [sender] + dst_tile.replica_device_strs - comm_key = "$".join(devices) - world_size = len(devices) - comm_config = CommunicatorConfig(comm_key) + # print("order: ", self.task_spec.strategy.order) + for i, j in self.task_spec.strategy.order: + spec_plan = self.task_spec.strategy.per_spec_plans[i] + dst_tile, src_tiles, indices_in_dst_tiles = ( + self.task_spec.dst_tile_to_src_tiles_map[i]) + src_tile, indices_in_dst_tile = (src_tiles[j], + indices_in_dst_tiles[j]) + + sender = spec_plan[j] + sender_worker = ( + self.collective_group.device_str_to_mesh_worker_map[sender]) + broadcast_group = (i, j) + devices = [sender] + dst_tile.replica_device_strs + comm_key = "$".join(devices) + world_size = len(devices) + + comm_config = CommunicatorConfig(comm_key) + + group_spec = self._broadcast_tasks[sender_worker].setdefault( + broadcast_group, + ReshardingBroadcastSpec(comm_key=comm_key, + world_size=world_size, + devices_ids=[ + self.collective_group. + device_str_to_device_id_map[sender] + ], + devices_global_rank=[0], + tensor_slices=[src_tile.offset], + recv_tile_shape=src_tile.tile_shape, + dtype=dtype)) + comm_config.add( + sender_worker, + self.collective_group.device_str_to_device_id_map[sender]) - group_spec = self._broadcast_tasks[sender_worker].setdefault( + for replica_index, receiver in enumerate( + dst_tile.replica_device_strs): + receiver_worker = (self.collective_group. + device_str_to_mesh_worker_map[receiver]) + group_spec = self._broadcast_tasks[receiver_worker].setdefault( broadcast_group, - ReshardingBroadcastSpec( - comm_key=comm_key, - world_size=world_size, - devices_ids=[ - self.collective_group. - device_str_to_device_id_map[sender] - ], - devices_global_rank=[0], - tensor_slices=[src_tile.offset], - recv_tile_shape=src_tile.tile_shape, - dtype=dtype)) + ReshardingBroadcastSpec(comm_key=comm_key, + world_size=world_size, + devices_ids=[], + devices_global_rank=[], + tensor_slices=[], + recv_tile_shape=dst_tile.tile_shape, + dtype=dtype)) + + group_spec.devices_ids.append( + self.collective_group.device_str_to_device_id_map[receiver]) + group_spec.devices_global_rank.append(1 + replica_index) + group_spec.tensor_slices.append(indices_in_dst_tile) comm_config.add( - sender_worker, - self.collective_group.device_str_to_device_id_map[sender]) + receiver_worker, + self.collective_group.device_str_to_device_id_map[receiver]) + + self.communicator_configs.add(comm_config) - for replica_index, receiver in enumerate( - dst_tile.replica_device_strs): - receiver_worker = (self.collective_group. 
-                                       device_str_to_mesh_worker_map[receiver])
-                    group_spec = self._broadcast_tasks[
-                        receiver_worker].setdefault(
-                            broadcast_group,
-                            ReshardingBroadcastSpec(
-                                comm_key=comm_key,
-                                world_size=world_size,
-                                devices_ids=[],
-                                devices_global_rank=[],
-                                tensor_slices=[],
-                                recv_tile_shape=dst_tile.tile_shape,
-                                dtype=dtype))
-
-                    group_spec.devices_ids.append(
-                        self.collective_group.
-                        device_str_to_device_id_map[receiver])
-                    group_spec.devices_global_rank.append(1 + replica_index)
-                    group_spec.tensor_slices.append(indices_in_dst_tile)
-                    comm_config.add(
-                        receiver_worker, self.collective_group.
-                        device_str_to_device_id_map[receiver])
-
-                self.communicator_configs.add(comm_config)
         return self._broadcast_tasks

     def create_resharding_communicators(self):
@@ -600,7 +610,7 @@ def __init__(self, device_strs, src_mesh, dst_mesh):
                     i + len(self.src_mesh.host_ips)] = self.dst_mesh.workers[i]
             for j in range(dst_mesh.num_devices_per_host):
                 device_str = self.dst_mesh.device_strs[
-                    i * src_mesh.num_devices_per_host + j]
+                    i * dst_mesh.num_devices_per_host + j]
                 self.device_str_to_rank_map[device_str] = (
                     i + len(src_mesh.host_ips), j)
                 self.device_str_to_mesh_worker_map[
@@ -853,6 +863,26 @@ def strategy(self):
                 "first.")
         return self._strategy

+    def generate_naive_order(self, mode):
+        """Return the naive order to submit resharding tasks."""
+
+        order = []
+        if mode == "sendrecv":
+            for i, (dst_tile, src_tiles,
+                    _) in enumerate(self.dst_tile_to_src_tiles_map):
+                for k, _ in enumerate(dst_tile.replica_device_strs):
+                    for j, _ in enumerate(src_tiles):
+                        order.append((i, k, j))
+        elif mode == "broadcast":
+            for i, (_, src_tiles,
+                    _) in enumerate(self.dst_tile_to_src_tiles_map):
+                for j, _ in enumerate(src_tiles):
+                    order.append((i, j))
+        else:
+            raise NotImplementedError
+
+        return order
+
     def get_participant_device_strs(self):
         """Identify all participant device strs (for NCCL setup) in this
         task spec."""
@@ -881,17 +911,24 @@ class ReshardingStrategy:
     """A data class for storing resharding communication information.

     Args:
+        mode (str): Two choices: ["sendrecv", "broadcast"].
         per_spec_plans (List[np.ndarray]): `per_spec_plan` is a list a np
             array, with length as len(spec.dst_tile_to_src_tiles_map), each
             array is with shape [len(dst_tile.devices), len(src_tiles)]; it
             specifies for each replica of a dst tile, how it should get
             the data from src_tiles (src tile replicas).
+        order (List[Tuple(int, ...)]): the order in which these NCCL
+            communication operations are submitted to the CUDA stream. When
+            mode is "sendrecv", each entry is a Tuple(int, int, int);
+            otherwise, each entry is a Tuple(int, int).
         is_local_allgather (bool): if this strategy involves post allgather
             operations.
     """

-    def __init__(self, per_spec_plans, is_local_allgather):
+    def __init__(self, mode, per_spec_plans, order, is_local_allgather):
+        self.mode = mode
         self.per_spec_plans = per_spec_plans
+        self.order = order
         self.is_local_allgather = is_local_allgather


@@ -937,16 +974,16 @@ def __init__(self, sharded_stages, schedule):
         self._create_resharding_specs()
         # Generate a send/recv strategies for all resharding tasks by looking
         # at their load.
- for _, _, var_spec_map in self.task_spec_iter(): + for src_mesh_idx, dst_mesh_idx, var_spec_map in self.task_spec_iter(): for _, spec in var_spec_map.items(): if global_config.resharding_mode == "send_recv": - strategy = ( - self._generate_send_recv_resharding_strategy_by_loads( - spec, self._sender_loads, self._receiver_loads)) + strategy = (self._generate_send_recv_resharding_strategy( + spec, self._schedule.meshes[src_mesh_idx], + self._schedule.meshes[dst_mesh_idx])) else: - strategy = ( - self._generate_broadcast_resharding_strategy_by_loads( - spec, self._sender_loads, self._receiver_loads)) + strategy = (self._generate_broadcast_resharding_strategy( + spec, self._schedule.meshes[src_mesh_idx], + self._schedule.meshes[dst_mesh_idx])) spec.set_resharding_strategy(strategy) @property @@ -1112,6 +1149,35 @@ def task_spec_iter(self): continue yield i, j, self.resharding_specs[i][j] + @staticmethod + def get_resources_info_in_mesh(mesh): + device_strs = [] + device_host_map = {} + nic_constraints = [] + + for i in range(mesh.num_hosts): + ip = mesh.host_info[i]["NodeManagerAddress"] + one_nic_constraint = [] + for device in mesh.devices[i]: + device_str = device_id_to_str(ip, device) + device_strs.append(device_str) + one_nic_constraint.append(device_str) + #TODO: Here we assume there is only one NIC in one host. + device_host_map[device_str] = ip + nic_constraints.append(one_nic_constraint) + return device_strs, device_host_map, nic_constraints + + @staticmethod + def _get_hardware_info_for_loadbalance(src_mesh, dst_mesh): + src_mesh_devices, src_device_host_map, src_nic_constraints = ( + CrossMeshCommunicator.get_resources_info_in_mesh(src_mesh)) + dst_mesh_devices, dst_device_host_map, dst_nic_constraints = ( + CrossMeshCommunicator.get_resources_info_in_mesh(dst_mesh)) + device_host_map = {**src_device_host_map, **dst_device_host_map} + nic_constraints = src_nic_constraints + dst_nic_constraints + return (src_mesh_devices, dst_mesh_devices, device_host_map, + nic_constraints) + @staticmethod def _generate_send_recv_resharding_strategy_by_loads( spec: ReshardingTaskSpec, src_loads, dst_loads): @@ -1137,7 +1203,197 @@ def _generate_send_recv_resharding_strategy_by_loads( src_loads[sender] += src_tileslice.slice_size dst_loads[receiver] += src_tileslice.slice_size per_spec_plans.append(per_spec_plan) - strategy = ReshardingStrategy(per_spec_plans, is_local_allgather) + + strategy = ReshardingStrategy("sendrecv", per_spec_plans, + spec.generate_naive_order("sendrecv"), + is_local_allgather) + return strategy + + def _generate_send_recv_resharding_strategy(self, spec: ReshardingTaskSpec, + src_mesh, dst_mesh): + if global_config.resharding_loadbalance_mode == "normal": + strategy = (self._generate_send_recv_resharding_strategy_by_loads( + spec, self._sender_loads, self._receiver_loads)) + elif global_config.resharding_loadbalance_mode == "no_loadbalance": + strategy = ( + self._generate_send_recv_resharding_strategy_by_no_load(spec)) + elif global_config.resharding_loadbalance_mode in ([ + "loadbalance_size", "loadbalance_order" + ]): + strategy = self.\ + _generate_send_recv_resharding_strategy_by_loadbalance( + spec, src_mesh, dst_mesh) + else: + raise NotImplementedError() + return strategy + + def _generate_broadcast_resharding_strategy(self, spec: ReshardingTaskSpec, + src_mesh, dst_mesh): + if global_config.resharding_loadbalance_mode == "normal": + strategy = (self._generate_broadcast_resharding_strategy_by_loads( + spec, self._sender_loads, self._receiver_loads)) + elif 
global_config.resharding_loadbalance_mode == "no_loadbalance":
+            strategy = (
+                self._generate_broadcast_resharding_strategy_by_no_load(spec))
+        elif global_config.resharding_loadbalance_mode in [
+                "loadbalance_size", "loadbalance_order"
+        ]:
+            strategy = (
+                self._generate_broadcast_resharding_strategy_by_loadbalance(
+                    spec, src_mesh, dst_mesh))
+        else:
+            raise NotImplementedError()
+        return strategy
+
+    @staticmethod
+    def _generate_send_recv_resharding_strategy_by_no_load(
+            spec: ReshardingTaskSpec):
+        """Generate the resharding strategy without considering loads."""
+        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec
+        per_spec_plans = []
+        for dst_tile, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:
+            # plan is a 2D array
+            per_spec_plan = np.empty(
+                (len(dst_tile.replica_device_strs), len(src_tileslices)),
+                dtype=object)
+            for receiver_idx, _ in enumerate(dst_tile.replica_device_strs):
+                for src_tileslice_idx, src_tileslice in enumerate(
+                        src_tileslices):
+                    sender = src_tileslice.replica_device_strs[0]
+                    # Choose an arbitrary sender without considering loads
+                    per_spec_plan[receiver_idx][src_tileslice_idx] = sender
+            per_spec_plans.append(per_spec_plan)
+
+        strategy = ReshardingStrategy("sendrecv", per_spec_plans,
+                                      spec.generate_naive_order("sendrecv"),
+                                      is_local_allgather)
+        return strategy
+
+    @staticmethod
+    def _generate_send_recv_resharding_strategy_by_loadbalance(
+            spec, src_mesh, dst_mesh):
+        """
+        Generate the send/recv-based resharding strategy by balancing
+        loads across devices and over time.
+        """
+
+        # pre-process
+        src_mesh_devices, dst_mesh_devices, device_host_map, nic_constraints = (
+            CrossMeshCommunicator._get_hardware_info_for_loadbalance(
+                src_mesh, dst_mesh))
+
+        works = []
+        for i, (dst_tile, src_tileslices,
+                _) in enumerate(spec.dst_tile_to_src_tiles_map):
+            for receiver in dst_tile.replica_device_strs:
+                for j, src_tileslice in enumerate(src_tileslices):
+                    senders = src_tileslice.replica_device_strs
+                    data_size = src_tileslice.tile_size
+                    works.append(
+                        SingleReshardingLoadBalancingWork(
+                            senders, [receiver], data_size))
+
+        # solve and get solution
+        task = ReshardingLoadBalancingTaskSolver(src_mesh_devices,
+                                                 dst_mesh_devices,
+                                                 device_host_map, works,
+                                                 nic_constraints)
+
+        sol_assigned_sender, sol_order = task.solve()
+
+        # post-process
+        per_spec_plans = []
+        rank_to_idx = []
+        cnt = 0
+        for i, (dst_tile, src_tileslices,
+                _) in enumerate(spec.dst_tile_to_src_tiles_map):
+            per_spec_plan = np.empty(
+                (len(dst_tile.replica_device_strs), len(src_tileslices)),
+                dtype=object)
+            for k, receiver in enumerate(dst_tile.replica_device_strs):
+                for j, src_tileslice in enumerate(src_tileslices):
+                    sender = sol_assigned_sender[cnt]
+                    per_spec_plan[k][j] = sender
+                    rank_to_idx.append((i, k, j))
+                    cnt += 1
+            per_spec_plans.append(per_spec_plan)
+
+        order = [rank_to_idx[i] for i in sol_order]
+        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec
+        strategy = ReshardingStrategy("sendrecv", per_spec_plans, order,
+                                      is_local_allgather)
+        return strategy
+
+    @staticmethod
+    def _generate_broadcast_resharding_strategy_by_no_load(
+            spec: ReshardingTaskSpec):
+        """
+        Generate the broadcast-based resharding strategy without
+        considering loads. For each tile, an arbitrary source replica
+        (the first one) is chosen as the sender.
+ """ + # pylint: disable=unused-argument + per_spec_plans = [] + for _, src_tileslices, _ in spec.dst_tile_to_src_tiles_map: + per_spec_plan = np.empty((len(src_tileslices),), dtype=object) + + for src_tileslice_idx, src_tileslice in enumerate(src_tileslices): + per_spec_plan[ + src_tileslice_idx] = src_tileslice.replica_device_strs[0] + per_spec_plans.append(per_spec_plan) + strategy = ReshardingStrategy("broadcast", per_spec_plans, + spec.generate_naive_order("broadcast"), + None) + return strategy + + @staticmethod + def _generate_broadcast_resharding_strategy_by_loadbalance( + spec, src_mesh, dst_mesh): + """ + Generate the broadcast-based resharding strategy by balancing + loads and along time. + """ + + # pre-process + src_mesh_devices, dst_mesh_devices, device_host_map, nic_constraints = ( + CrossMeshCommunicator._get_hardware_info_for_loadbalance( + src_mesh, dst_mesh)) + + works = [] + for i, (dst_tile, src_tileslices, + _) in enumerate(spec.dst_tile_to_src_tiles_map): + for j, src_tileslice in enumerate(src_tileslices): + senders = src_tileslice.replica_device_strs + receivers = dst_tile.replica_device_strs + data_size = src_tileslice.tile_size + works.append( + SingleReshardingLoadBalancingWork(senders, receivers, + data_size)) + + # solve and get solution + task = ReshardingLoadBalancingTaskSolver(src_mesh_devices, + dst_mesh_devices, + device_host_map, works, + nic_constraints) + + sol_assigned_sender, sol_order = task.solve() + + # post-process + per_spec_plans = [] + rank_to_idx = [] + cnt = 0 + for i, (dst_tile, src_tileslices, + _) in enumerate(spec.dst_tile_to_src_tiles_map): + per_spec_plan = np.empty((len(src_tileslices),), dtype=object) + for j, src_tileslice in enumerate(src_tileslices): + sender = sol_assigned_sender[cnt] + per_spec_plan[j] = sender + rank_to_idx.append((i, j)) + cnt += 1 + per_spec_plans.append(per_spec_plan) + + order = [rank_to_idx[i] for i in sol_order] + strategy = ReshardingStrategy("broadcast", per_spec_plans, order, None) return strategy @staticmethod @@ -1148,7 +1404,6 @@ def _generate_broadcast_resharding_strategy_by_loads( For each tile, I not only allow one source to provide the tile. """ # pylint: disable=unused-argument - #TODO(hexu): (1) allow for multiple sources. (2) update load on the fly. per_spec_plans = [] dst_loads = None for _, src_tileslices, _ in spec.dst_tile_to_src_tiles_map: @@ -1164,7 +1419,9 @@ def _generate_broadcast_resharding_strategy_by_loads( per_spec_plan[src_tileslice_idx] = sender src_loads[sender] += src_tileslice.slice_size per_spec_plans.append(per_spec_plan) - strategy = ReshardingStrategy(per_spec_plans, None) + strategy = ReshardingStrategy("broadcast", per_spec_plans, + spec.generate_naive_order("broadcast"), + None) return strategy @staticmethod @@ -1179,3 +1436,468 @@ def _args_between(src_stage, dst_stage): src_indices.append(i) dst_indices.append(dst_stage.invars.index(var)) return resharding_vars, src_indices, dst_indices + + +SingleReshardingLoadBalancingWork = namedtuple( + "SingleReshardingLoadBalancingWork", ["senders", "receivers", "data_size"]) +SingleAbstractedLoadBalancingWork = namedtuple( + "SingleAbstractedLoadBalancingWork", + ["sender_ids", "receiver_ids", "duration"]) + + +class ReshardingLoadBalancingTaskSolver: + """This is class of solver for load balancing problem""" + + def __init__(self, + src_mesh_devices, + dst_mesh_devices, + device_host_map, + works, + nic_contraints, + host_bridge_contraints=None): + """We define the load balancing problem in resharding problem. 
+ Here both send_recv and broadcast based implementation could + be formulated in this way. + + Args: + src_mesh_devices: All gpus in src mesh. + dst_mesh_devices: All gpus in dst mesh. + device_host_map: a map from device to its corresponding host. + works (List[SingleReshardingLoadBalancingWork]): all works to + be scheduled in this task. + nic_contraints (List[List[device]]): each list[device] contains + a set of devices that competes for the same NIC. + Now I assmue sender and receiver do not share NIC. + The assumption is met in nic_contraints. + I assume these constraints are disjoint sets. + """ + self.src_mesh_devices = src_mesh_devices + self.dst_mesh_devices = dst_mesh_devices + self.all_devices = list( + set(src_mesh_devices).union(set(dst_mesh_devices))) + self.device_host_map = device_host_map + self.works = works + self.nic_contraints = nic_contraints + self.host_bridge_contraints = host_bridge_contraints + + # self.print_task() + + def solve(self): + """ + Return two data + 1. The first List[device] represents which sender to choose + for each work. + 2. The second List[int] represents the order to execute + these works. + """ + + # Deal with the case when a src device share the same NIC with a tar + # device. Now I assmue they do not share NIC. The assumption is met + # in nic_contraints so we do not need to deal with it in this method. + + tmp_device_to_worker_id_map = { + device: idx for idx, device in enumerate(self.all_devices) + } + for nic_contraint in self.nic_contraints: + min_id = min( + tmp_device_to_worker_id_map[device] for device in nic_contraint) + for device in nic_contraint: + tmp_device_to_worker_id_map[device] = min_id + + device_to_worker_id_map = {} + worker_id_to_devices = {} + n_workers = 0 + for idx, device in enumerate(self.all_devices): + if tmp_device_to_worker_id_map[device] == idx: + device_to_worker_id_map[device] = n_workers + worker_id_to_devices[n_workers] = [device] + n_workers += 1 + else: + group_head_device = self.all_devices[ + tmp_device_to_worker_id_map[device]] + worker_id = device_to_worker_id_map[group_head_device] + device_to_worker_id_map[device] = worker_id + worker_id_to_devices[worker_id].append(device) + + abstract_works = [] + for work in self.works: + sender_ids = set() + for sender in work.senders: + sender_ids.add(device_to_worker_id_map[sender]) + sender_ids = list(sender_ids) + sender_ids.sort() + receiver_ids = set() + for receiver in work.receivers: + receiver_ids.add(device_to_worker_id_map[receiver]) + receiver_ids = list(receiver_ids) + receiver_ids.sort() + time_spent = work.data_size + + abstract_works.append( + SingleAbstractedLoadBalancingWork(sender_ids, receiver_ids, + time_spent)) + + if global_config.resharding_loadbalance_mode == "loadbalance_size": + task = LoadBalancingOverSizeTaskSolver(n_workers, abstract_works) + else: + if global_config.loadbalance_order_algo == "search": + task = LoadBalancingTaskSolverSearchAlgo( + n_workers, abstract_works) + else: + task = LoadBalancingTaskSolverGreedyAlgo( + n_workers, abstract_works) + + sol_assigned_sender_id, sol_order = task.solve() + + sol_assigned_sender = [] + for work, worker_id in zip(self.works, sol_assigned_sender_id): + selected_sender = None + for sender in work.senders: + if device_to_worker_id_map[sender] == worker_id: + selected_sender = sender + break + assert selected_sender is not None + sol_assigned_sender.append(selected_sender) + return sol_assigned_sender, sol_order + + def print_task(self): + print("\nTask[START]") + 
print(f"src_mesh_devices: {self.src_mesh_devices}") + print(f"dst_mesh_devices: {self.dst_mesh_devices}") + print(f"device_host_map: {self.device_host_map}") + print("works:") + for work in self.works: + print(work) + print("nic_contraints:") + for contraint in self.nic_contraints: + print(contraint) + print("Task[END]\n") + + +class AbstractedLoadBalancingTaskSolver(ABC): + """This is class of solver for abstracted load balancing problem""" + + def __init__(self, n_workers, works): + """We abstract the load balancing problem into this mathematically + clear form. + + Args: + n_workers (int): The total number of single threaded + workers in this loadbalancing task. + works (List[SingleAbstractedLoadBalancingWork]): all works to + be scheduled in this task. + """ + self.n_workers = n_workers + self.n_works = len(works) + self.works = works + self.loads = [0 for _ in range(n_workers)] + + # self.print_task() + + @abstractmethod + def solve(self): + """ + Return two list[int] of length n_works + 1. The first represents which sender to choose for each work. + 2. The second represents the order to execute these works. + """ + raise NotImplementedError + + def print_task(self): + print("AbstractedTask[START]") + print(f"n_workers: {self.n_workers}") + print("works:") + for work in self.works: + print(work) + print("AbstractedTask[END]") + + +class LoadBalancingTaskSolverGreedyAlgo(AbstractedLoadBalancingTaskSolver): + """Implementation of load balance: use randomized greedy algorithm""" + + def find_one_random_concurrent_set_of_works(self, works_ids): + """This method finds one set of works that could be run + concurrently. + + Args: + works_ids (List[int]): The ids of works that could be + selected. + + Returns: + one_concurrent_works_ids (list[int]): The ids of works + selected in this method. + one_concurrent_selected_senders (list[int]): The assigned + senders for the selected works. + """ + + def probability_of_being_selected(loads): + # these weights could be more carefully tuned. + max_weight = max(loads) + weights = [max_weight - weight + 1 for weight in loads] + return weights + + used = [False for _ in range(self.n_workers)] + perm = np.random.permutation(np.array(works_ids)) + one_concurrent_works_ids = [] + one_concurrent_selected_senders = [] + for i in perm: + work = self.works[i] + receivers_availability = True + for receiver in work.receiver_ids: + if used[receiver]: + receivers_availability = False + break + if not receivers_availability: + continue + + available_senders = [] + for sender in work.sender_ids: + if not used[sender]: + available_senders.append(sender) + if not available_senders: + continue + + weights = probability_of_being_selected( + [self.loads[sender] for sender in available_senders]) + selected_sender = random.choices(available_senders, + weights=weights)[0] + + used[selected_sender] = True + for receiver in work.receiver_ids: + used[receiver] = True + + one_concurrent_works_ids.append(i) + one_concurrent_selected_senders.append(selected_sender) + return one_concurrent_works_ids, one_concurrent_selected_senders + + def find_best_concurrent_set_of_works(self, works_ids, n_rounds=100): + """ + One simple strategy is that everytime we choose the maximum number + of works and minimize std and put them into the sequence. + The simple logic behind is to maximize concurrency. + + Args: + works_ids (List[int]): All available works waiting for running. + n_rounds (int, optional): The number of rounds to run for finding + the best set of works. Defaults to 100. 
+ """ + + def calc_std(data): + ave = sum(data) / len(data) + std = (sum((x - ave)**2 for x in data) / len(data))**0.5 + return std + + # def calc_max(A): + # return max(A) + + max_num = 0 + min_std = None + best_concurrent_works_ids = [] + best_concurrent_selected_senders = [] + for _ in range(n_rounds): + one_concurrent_works_ids, one_concurrent_selected_senders = \ + self.find_one_random_concurrent_set_of_works(works_ids) + num = len(one_concurrent_works_ids) + if num < max_num: + continue + + loads = list(self.loads) + for work_id, selected_sender in zip( + one_concurrent_works_ids, one_concurrent_selected_senders): + loads[selected_sender] += self.works[work_id].duration + + # here we could use different criterions + std = calc_std(loads) # calc_max(loads) + # std = calc_std( + # [self.works[i].duration for i in range(one_concurrent_works_ids)] + # ) + + if num > max_num or (num == max_num and std < min_std): + max_num = num + min_std = std + best_concurrent_works_ids = one_concurrent_works_ids + best_concurrent_selected_senders = ( + one_concurrent_selected_senders) + return best_concurrent_works_ids, best_concurrent_selected_senders + + def solve(self): + sol_assigned_sender_id = [None for _ in range(len(self.works))] + sol_order = [] + while True: + available_works_ids = [ + i for i in range(len(self.works)) if i not in sol_order + ] + best_concurrent_works_ids, best_concurrent_selected_senders = \ + self.find_best_concurrent_set_of_works(available_works_ids) + + for work_id, sender_id in zip(best_concurrent_works_ids, + best_concurrent_selected_senders): + sol_order.append(work_id) + sol_assigned_sender_id[work_id] = sender_id + self.loads[sender_id] += self.works[work_id].duration + + if len(sol_order) == len(self.works): + break + + assert None not in sol_assigned_sender_id + + return sol_assigned_sender_id, sol_order + + +class LoadBalancingTaskSolverSearchAlgo(AbstractedLoadBalancingTaskSolver): + """Implementation of load balance: use search algorithm with pruning""" + + def __init__(self, n_workers, works): + super().__init__(n_workers, works) + + self.sol_assigned_sender_id = [None for _ in range(len(self.works))] + self.sol_order = [] + self.minimal_finish_time = None + + self.cur_assigned_sender_id = [None for _ in range(len(self.works))] + self.cur_order = [] + + self.start_time = time.time() + self.search_time_threshold = 1 + + def evaluate_one_solution(self, assigned_sender_id, order): + """Given current task-sender assigment and order to submit + these tasks, this method return the finishing time of each + receiver for the current schedule as solution. + To get the finishing time, this method just simulates the + whole process. + + Args: + assigned_sender_id: This variable contains idx of sender + for each task. + order: The order to submit different tasks. + + Returns: + current_time (list[int]): the time for each receiver + after finishing all the tasks assigned to it. 
+ """ + current_time = [0 for _ in range(self.n_workers)] + + for i in order: + work = self.works[i] + sender_id = assigned_sender_id[i] + mx_time = max([current_time[sender_id]] + [ + current_time[receiver_id] for receiver_id in work.receiver_ids + ]) + current_time[sender_id] = mx_time + work.duration + for receiver_id in work.receiver_ids: + current_time[receiver_id] = mx_time + work.duration + return current_time + + def heuristic(self, current_time, remained_work_ids): + """ Given the current time for each receiver to finish + its assigned works, and the remained work to be + assigned, this method estimate the minimal amount + of time to finish all works. If the minimal amount + of time to finish all works is still longer than + current best solution, then we could prune the current + search branch. + + Args: + current_time (list[int]): the time for each receiver + after finishing all the tasks assigned to it. + remained_work_ids (list[int]): The ids of works remained + to be assigned to workers. + + Returns: + int: the minimal amount of time to finish all works + with current assignment and order schedule. + """ + remained_time_lowerbound = [0 for _ in range(self.n_workers)] + for i in remained_work_ids: + work = self.works[i] + sender_id_with_mintime = -1 + for sender_id in work.sender_ids: + if sender_id_with_mintime == -1: + sender_id_with_mintime = sender_id + elif (remained_time_lowerbound[sender_id] + + current_time[sender_id] < + remained_time_lowerbound[sender_id_with_mintime] + + current_time[sender_id_with_mintime]): + sender_id_with_mintime = sender_id + # heuristic function could be continuely improved. + remained_time_lowerbound[sender_id_with_mintime] += work.duration + for receiver_id in work.receiver_ids: + remained_time_lowerbound[receiver_id] += work.duration + + max_time = max( + x + y for x, y in zip(remained_time_lowerbound, current_time)) + return max_time + + def dfs(self, depth): + """This is the Depth First Search function + to search the order of submitting works + and sender for each work. + + Args: + depth (int): The depth of the DFS; In other + words, we are deciding the depth_th task in + order array. 
+ """ + if time.time() - self.start_time > self.search_time_threshold: + return + + current_time = self.evaluate_one_solution(self.cur_assigned_sender_id, + self.cur_order) + + if depth == len(self.works): + finish_time = max(current_time) + if (self.minimal_finish_time is None or + finish_time < self.minimal_finish_time): + self.minimal_finish_time = finish_time + self.sol_assigned_sender_id = list(self.cur_assigned_sender_id) + self.sol_order = list(self.cur_order) + return + + remained_work_ids = [ + i for i in range(len(self.works)) if i not in self.cur_order + ] + + heuristic = self.heuristic(current_time, remained_work_ids) + if (self.minimal_finish_time is not None and + heuristic > self.minimal_finish_time): + return + + for i in remained_work_ids: + self.cur_order.append(i) + work = self.works[i] + for sender_id in work.sender_ids: + self.cur_assigned_sender_id[i] = sender_id + self.dfs(depth + 1) + self.cur_assigned_sender_id[i] = None + self.cur_order.pop() + + def solve(self): + + self.dfs(depth=0) + + assert None not in self.sol_assigned_sender_id + + return self.sol_assigned_sender_id, self.sol_order + + +class LoadBalancingOverSizeTaskSolver(AbstractedLoadBalancingTaskSolver): + """Implementation of load balance: only consider workers' workloads""" + + def __init__(self, n_workers, works): + super().__init__(n_workers, works) + + self.sol_assigned_sender_id = [None for _ in range(len(self.works))] + self.sol_order = [] + + def solve(self): + for i, work in enumerate(self.works): + loads = {sender: self.loads[sender] for sender in work.sender_ids} + sender = min(loads, key=loads.get) + self.sol_assigned_sender_id[i] = sender + self.loads[sender] += work.duration + self.sol_order.append(i) + + assert None not in self.sol_assigned_sender_id + + return self.sol_assigned_sender_id, self.sol_order diff --git a/benchmark/alpa/benchmark_one_case_moe_inference.py b/benchmark/alpa/benchmark_one_case_moe_inference.py index f434c1331..b9a7d28f5 100644 --- a/benchmark/alpa/benchmark_one_case_moe_inference.py +++ b/benchmark/alpa/benchmark_one_case_moe_inference.py @@ -18,12 +18,13 @@ def create_infer_params_aval(rngkey, model, batch): params = jax.eval_shape(model.init, rngkey, batch["input_ids"], - batch["attention_mask"], - batch["token_type_ids"], batch["position_ids"]) + batch["attention_mask"], batch["token_type_ids"], + batch["position_ids"]) params = jax.eval_shape( lambda p: jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float16), p), params) - return params + return params + def get_infer_step(parallel_method, model): @@ -41,7 +42,7 @@ def infer_step(params, batch, rng_key): loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss - + return parallelize(infer_step, method=parallel_method, donate_argnums=()) @@ -60,8 +61,8 @@ def prepare_moe_inference_input_and_model(benchmark_case, if correct_expert_group_size: rang_factor = 1 expected_expert_group_size = min( - expert_group_size, - batch_size * seq_len // benchmark_case.num_micro_batches // 1 // rang_factor) + expert_group_size, batch_size * seq_len // + benchmark_case.num_micro_batches // 1 // rang_factor) if expected_expert_group_size != expert_group_size: print("- Expected expert group size should be {}, " "but got {}. 
Will reset it".format(expected_expert_group_size, @@ -152,8 +153,8 @@ def benchmark_moe_inference_internal(benchmark_case, infer_step = get_infer_step(method, model) - (latencies, max_mem_allocated, compilation_times, - executable, per_stage_weight_mem, + (latencies, max_mem_allocated, compilation_times, executable, + per_stage_weight_mem, per_stage_peak_mem) = compile_and_benchmark_pipeshard_inference_executable( benchmark_case.parallel_mode, niter, diff --git a/benchmark/alpa/gen_serving_database.py b/benchmark/alpa/gen_serving_database.py index 15c968ce7..684a31792 100644 --- a/benchmark/alpa/gen_serving_database.py +++ b/benchmark/alpa/gen_serving_database.py @@ -8,7 +8,6 @@ from alpa_serve.profiling import ProfilingDatabase - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--input", type=str, default="inference_prof_res.tsv") diff --git a/benchmark/alpa/resharding/README.md b/benchmark/alpa/resharding/README.md new file mode 100644 index 000000000..5ba4e1421 --- /dev/null +++ b/benchmark/alpa/resharding/README.md @@ -0,0 +1,70 @@ +# Benchmark +This folder contains benchmarking code for cross mesh resharding, corresponding to the experiment section in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322). + +To make the benchmark feasible in a short amount of time, this documentation provides: Instructions for benchmarking on an AWS p3.8xlarge cluster. You can use these to quickly run cross mesh resharding using Alpa and get the statistics of the performance. The statistics may be different from that in our papaer if your cluster is not an AWS p3.8xlarge cluster. +There are two types of experiments for benchmarking: +- Single device to multiple devices microbenchmark: corronspond to section 5.1.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322). +- Multiple devices to multiple devices microbenchmark: corronspond to section 5.1.2 and 5.3.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322). + +## Benchmark Steps + +### Cluster Preparation + +Prepare 5 AWS p3.8xlarge instances and put them in the same Placement Group. + +### Start a Ray Cluster +Alpa uses a distributed framework Ray to manage the cluster and distributed workers. +Here, we provide instructions for manually launching a ray cluster. +You can also refer to the Ray [documentation](https://docs.ray.io/en/latest/cluster/quickstart.html#) for more methods on launching and managing ray clusters. + +1. Pick one node as the head node and run the command below on it + ``` + ray start --head + ``` +2. For all other 4 nodes, connect them to the head node following the instructions printed by the previous command. + ``` + # The command should look like this, but with the ip address and password printed by the previous command. + ray start --address='172.31.31.37:6379' --redis-password='5241590000000000' + ``` + +You can check the cluster status by +``` +ray status +``` +You should be able to see the number of CPUs and GPUs available on your cluster. We should have 20 GPUs to proceed. +All nodes should have alpa installed. + +### Single device to multiple devices microbenchmark +Run all benchmark tests with all GPUs in your cluster. +``` +python3 benchmark.py --suite 1-to-m +``` +The result will be saved in `tmp/1_to_m_result.json`. In this set of experiment, the sender mesh has only 1 GPU. We vary the number of GPUs in the receiver mesh. 
+In the first half of the benchmark tests, the receiver mesh has 1 node and the number of GPUs in this node varies from 1 to 4. In the second half, the number of GPUs per node is fixed at 2, but the number of nodes in the receiver mesh grows from 1 to 4. For more details, please refer to `perf_1_to_m_suite` in `suite.py`.
+
+If you only want to run one test case,
+```
+python3 benchmark_cross_mesh_resharding.py --suite 1-to-m --n-nodes 1 --gpu-per-node 4 --resharding-mode send_recv --resharding-loadbalance-mode normal
+```
+Here, we take a (1, 4) destination mesh as an example; you can also choose other cases.
+You can use the `--resharding-mode`, `--resharding-loadbalance-mode`, and `--use-local-allgather` flags
+to specify the configurations for cross mesh resharding.
+
+### Multiple devices to multiple devices microbenchmark
+Similar to the previous subsection.
+```
+python3 benchmark.py --suite n-to-m
+```
+The result will be saved in `tmp/n_to_m_result.json`. In this set of experiments, we move to more complicated cases where both the sender mesh and the receiver mesh have multiple nodes. For more details, please refer to `perf_n_to_m_suite` in `suite.py`.
+
+If you only want to run one test case,
+```
+python3 benchmark_cross_mesh_resharding.py --suite n-to-m --case case1 --resharding-mode send_recv --resharding-loadbalance-mode normal
+```
+Here, we take case1 as an example; you can choose other cases by referring to `suite.py`. As above, you can
+specify the configurations for cross mesh resharding with the same flags.
+
+## Result
+
+With the benchmark scripts above, you can compare the time spent under different resharding configurations
+and relate the statistics to the conclusions in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322).
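As a quick way to inspect the collected statistics, the snippet below is a minimal sketch (not part of this patch): each entry in the JSON files written by `benchmark.py` is the result dict built in `benchmark_cross_mesh_resharding.py`, so a few lines of Python suffice to compare configurations.

```python
# Sketch: summarize tmp/n_to_m_result.json (or tmp/1_to_m_result.json).
# The keys below match the result dict returned by benchmark_one_case_internal.
import json

with open("tmp/n_to_m_result.json") as f:
    results = json.load(f)

for r in results:
    print(f'{r["resharding_mode"]:>9} | '
          f'{r["resharding_loadbalance_mode"]:>17} | '
          f'allgather={r["use_local_allgather"]} | '
          f'mean {r["exec_time_mean"]:.4f}s, var {r["exec_time_var"]:.3e}')
```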
diff --git a/benchmark/alpa/resharding/benchmark.py b/benchmark/alpa/resharding/benchmark.py new file mode 100644 index 000000000..357752e5f --- /dev/null +++ b/benchmark/alpa/resharding/benchmark.py @@ -0,0 +1,101 @@ +"""The entry point of intra-op + inter-op parallelism benchmark.""" +import argparse +import json +import multiprocessing as mp +import os +import time + +from benchmark_cross_mesh_resharding import benchmark_one_case_internal +import suite + + +def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs): + result = benchmark_one_case_internal(*args, **kwargs) + result_namespace.result = result + + +def benchmark_one_case(*args, use_separate_process=False, **kwargs): + if not use_separate_process: + return benchmark_one_case_internal(*args, **kwargs) + ctx = mp.get_context("spawn") + manager = ctx.Manager() + result_namespace = manager.Namespace() + p = ctx.Process(target=benchmark_and_write_to_namespace, + args=(result_namespace, *args), + kwargs=kwargs) + p.start() + p.join() + if p.exitcode != 0: + return -1, -1, [-1], -1, None + return result_namespace.result + + +def benchmark_n_to_m_suite(): + os.makedirs("tmp", exist_ok=True) + + result_file = "tmp/n_to_m_result.json" + result = [] + + benchmark_cases = suite.perf_n_to_m_suite + resharding_config_list = suite.resharding_n_to_m_configs + + # Run all cases + for case_name, benchmark_case in benchmark_cases.items(): + # Run one case + for config in resharding_config_list: + print("Working on {}: {}, config: {}".format( + case_name, str(benchmark_case), str(config))) + one_result = benchmark_one_case( + benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape, + benchmark_case.src_sharding_spec, + benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape, + config["resharding_mode"], config["use_local_allgather"], + config["resharding_loadbalance_mode"]) + + print(one_result) + result.append(one_result) + json.dump(result, open(result_file, "w"), indent=4) + + time.sleep(0.1) # for ctrl+c to work + + +def benchmark_1_to_m_suite(): + os.makedirs("tmp", exist_ok=True) + + result_file = "tmp/1_to_m_result.json" + result = [] + + benchmark_cases = suite.perf_1_to_m_suite + resharding_config_list = suite.resharding_1_to_m_configs + + # Run all cases + for case_name, benchmark_case in benchmark_cases.items(): + # Run one case + for config in resharding_config_list: + print("Working on {}: {}, config: {}".format( + case_name, str(benchmark_case), str(config))) + one_result = benchmark_one_case( + benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape, + benchmark_case.src_sharding_spec, + benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape, + config["resharding_mode"], config["use_local_allgather"], + config["resharding_loadbalance_mode"]) + print(one_result) + result.append(one_result) + json.dump(result, open(result_file, "w"), indent=4) + + time.sleep(0.1) # for ctrl+c to work + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--suite", + choices=["1-to-m", "n-to-m"], + type=str, + required=True) + args = parser.parse_args() + + if args.suite == "1-to-m": + benchmark_1_to_m_suite() + else: + benchmark_n_to_m_suite() diff --git a/benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py b/benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py new file mode 100644 index 000000000..2a37b90f6 --- /dev/null +++ b/benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py @@ -0,0 +1,288 @@ +"""Test cross-mesh resharding.""" +import argparse + 
+from jax import xla +from jax.core import Var +from jax._src.abstract_arrays import ShapedArray +from jax.interpreters.pxla import spec_to_indices +import jax.numpy as jnp +import numpy as np +import ray + +from alpa import init +from alpa.device_mesh import (create_remote_array_refs, + get_global_virtual_physical_mesh) +from alpa.mesh_executable import next_mesh_executable_uuid +from alpa.global_env import global_config +from alpa.pipeline_parallel.runtime_emitter import PipelineInstEmitter +from alpa.pipeline_parallel.cross_mesh_resharding import ( + CollectiveGroup, ReshardingTaskSpec, CrossMeshCommunicator, + SymbolicReshardingTask, SymbolicBroadcastReshardingTask) +from alpa.pipeline_parallel.pipeshard_executable import ( + AllocateZeroWorkerExecutableConfig, PipelineInstruction, + PipeshardMeshWorkerExecuable) +from alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray +from alpa.util import get_shard_shape +from alpa.timer import timers + +import suite + + +def get_device_meshes(src_mesh_shape, dst_mesh_shape): + virtual_mesh = get_global_virtual_physical_mesh() + src_num_host = src_mesh_shape[0] + dst_num_host = dst_mesh_shape[0] + assert virtual_mesh.num_hosts >= src_num_host+dst_num_host,\ + "Error: There are not enough nodes for this test case" + src_mesh = virtual_mesh.slice_2d(range(src_num_host), + [range(src_mesh_shape[1])] * + src_num_host).get_physical_mesh() + dst_host_indices = range(src_num_host, src_num_host + dst_num_host) + dst_device_indices = [range(dst_mesh_shape[1])] * dst_num_host + dst_mesh = virtual_mesh.slice_2d(dst_host_indices, + dst_device_indices).get_physical_mesh() + return src_mesh, dst_mesh + + +def get_mean_and_variance(results): + assert len(results) == 13 + results = results[3:] + mean = np.mean(results) + var = np.var(results) + return mean, var + + +def benchmark_one_case_internal( + src_mesh_shape, + dst_mesh_shape, + src_sharding_spec, + dst_sharding_spec, + tensor_shape, + resharding_mode="send_recv", + use_local_allgather=True, + resharding_loadbalance_mode="normal", +): + + global_config.resharding_mode = resharding_mode + global_config.resharding_loadbalance_mode = resharding_loadbalance_mode + global_config.use_local_allgather = use_local_allgather + + init(cluster="ray") + + src_mesh, dst_mesh = get_device_meshes(src_mesh_shape, dst_mesh_shape) + + var = Var(0, "", ShapedArray(tensor_shape, jnp.int32)) + + # Resharding task spec and send/recv strategy + src_loads = {src: 0 for src in src_mesh.device_strs} + dst_loads = {dst: 0 for dst in dst_mesh.device_strs} + if resharding_mode == "send_recv": + rewrite_dst_sharding_spec = CrossMeshCommunicator._rewrite_allgather_spec( + dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape) + else: + rewrite_dst_sharding_spec = dst_sharding_spec + src_array = VirtualDistributedArray(device_mesh=src_mesh, + aval=var.aval, + sharding_spec=src_sharding_spec) + dst_array = VirtualDistributedArray(device_mesh=dst_mesh, + aval=var.aval, + sharding_spec=rewrite_dst_sharding_spec) + task_spec = ReshardingTaskSpec(src_array, dst_array, dst_sharding_spec) + + if resharding_mode == "send_recv": + if global_config.resharding_loadbalance_mode == "normal": + strategy = (CrossMeshCommunicator. + _generate_send_recv_resharding_strategy_by_loads( + task_spec, src_loads, dst_loads)) + elif global_config.resharding_loadbalance_mode == "no_loadbalance": + strategy = ( + CrossMeshCommunicator. 
+ _generate_send_recv_resharding_strategy_by_no_load(task_spec)) + elif global_config.resharding_loadbalance_mode in [ + "loadbalance_size", "loadbalance_order" + ]: + strategy = (CrossMeshCommunicator. + _generate_send_recv_resharding_strategy_by_loadbalance( + task_spec, src_mesh, dst_mesh)) + else: + if global_config.resharding_loadbalance_mode == "normal": + strategy = (CrossMeshCommunicator. + _generate_broadcast_resharding_strategy_by_loads( + task_spec, src_loads, dst_loads)) + elif global_config.resharding_loadbalance_mode == "no_loadbalance": + strategy = ( + CrossMeshCommunicator. + _generate_broadcast_resharding_strategy_by_no_load(task_spec)) + elif global_config.resharding_loadbalance_mode in [ + "loadbalance_size", "loadbalance_order" + ]: + strategy = (CrossMeshCommunicator. + _generate_broadcast_resharding_strategy_by_loadbalance( + task_spec, src_mesh, dst_mesh)) + + task_spec.set_resharding_strategy(strategy) + + # Resharding task. Compile send/recv from strategy and allgather. + collective_group = CollectiveGroup(task_spec.get_participant_device_strs(), + src_mesh, dst_mesh) + if global_config.eagerly_create_communicators: + collective_group.instantiate_now() + else: + collective_group.instantiate() + if resharding_mode == "send_recv": + task = SymbolicReshardingTask(task_spec, collective_group, src_mesh, + dst_mesh) + else: + task = SymbolicBroadcastReshardingTask(task_spec, collective_group, + src_mesh, dst_mesh) + + if global_config.eagerly_create_communicators: + task.create_resharding_communicators() + + # Compile pipeline instructions + instruction_lists = {worker: [] for worker in src_mesh.workers} + for worker in dst_mesh.workers: + instruction_lists[worker] = [] + executable_config_lists = {worker: [] for worker in dst_mesh.workers} + src_uuid = 21474 + dst_uuid = 21475 + # allocate the buffer + exec_uuid = next_mesh_executable_uuid() + config = AllocateZeroWorkerExecutableConfig( + exec_uuid, [get_shard_shape(var.aval, rewrite_dst_sharding_spec)], + [var.aval.dtype]) + output_uuids = [dst_uuid] + for worker in dst_mesh.workers: + executable_config_lists[worker].append(config) + in_uuids = [] + out_uuids = output_uuids + instruction_lists[worker].append( + PipelineInstruction.run(config.exec_uuid, + in_uuids, + out_uuids, { + "sync_before": False, + "sync_after": False + }, + info="allocate zero for recv")) + # Create resharding task + + if resharding_mode == "send_recv": + PipelineInstEmitter._compile_resharding_task(src_uuid, task, dst_uuid, + instruction_lists) + else: + PipelineInstEmitter._compile_broadcast_resharding_task( + src_mesh, src_uuid, task, dst_uuid, instruction_lists) + + exec_uuids = {} + + # Compile Pipeline Executable + for worker in src_mesh.workers: + exec_uuid = next_mesh_executable_uuid() + # print(worker, exec_uuid) + worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecuable, + instruction_lists[worker], [src_uuid], [], + [], [], [], + [False] * src_mesh.num_devices_per_host) + exec_uuids[worker] = exec_uuid + for worker in dst_mesh.workers: + exec_uuid = next_mesh_executable_uuid() + # print(worker, exec_uuid) + worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecuable, + instruction_lists[worker], [], [dst_uuid], + executable_config_lists[worker], [], [], + [False] * dst_mesh.num_devices_per_host) + exec_uuids[worker] = exec_uuid + + # Prepare array and shard args + test_array = np.arange(np.prod(var.aval.shape), + dtype=var.aval.dtype).reshape(var.aval.shape) + indices = spec_to_indices(var.aval.shape, 
src_sharding_spec) + test_array = xla.canonicalize_dtype(test_array) + input_refs = src_mesh.shard_args_to_bufs([indices], (False,), (False,), + None, [test_array]) + input_refs = np.array(input_refs) + input_uuids = [ref.uuid for ref in input_refs] + output_refs, output_uuids = create_remote_array_refs(dst_mesh) + + # Run executables + time_spend = [] + for _ in range(13): + timers("overall_resharding_time").start() + for worker in src_mesh.workers: + worker.run_executable.remote(exec_uuids[worker], + input_uuids, [], + sync_for_timer=True, + collect_trace=False) + for worker in dst_mesh.workers: + worker.run_executable.remote(exec_uuids[worker], [], + output_uuids, + sync_for_timer=True, + collect_trace=False) + + dst_mesh.sync_workers(sync_all_devices=True) + timers("overall_resharding_time").stop() + time_spend.append(timers("overall_resharding_time").elapsed(mode="sum")) + timers("overall_resharding_time").reset() + + mean_time, var_time = get_mean_and_variance(time_spend) + result = { + "src_mesh_shape": src_mesh_shape, + "dst_mesh_shape": dst_mesh_shape, + "src_sharding_spec": str(src_sharding_spec), + "dst_sharding_spec": str(dst_sharding_spec), + "tensor_shape": tensor_shape, + "resharding_mode": resharding_mode, + "use_local_allgather": use_local_allgather, + "resharding_loadbalance_mode": resharding_loadbalance_mode, + "exec_time_mean": mean_time, + "exec_time_var": var_time + } + + # Delete executables + for worker in src_mesh.workers: + worker.delete_executable.remote(exec_uuids[worker]) + for worker in dst_mesh.workers: + worker.delete_executable.remote(exec_uuids[worker]) + + src_mesh.shutdown() + dst_mesh.shutdown() + + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--suite", + type=str, + required=True, + choices=["1-to-m", "n-to-m"]) + parser.add_argument("--case", type=str) + parser.add_argument("--n-nodes", type=int, default=1) + parser.add_argument("--gpu-per-node", type=int, default=1) + parser.add_argument("--resharding-mode", + type=str, + required=True, + choices=["send_recv", "broadcast"]) + parser.add_argument("--resharding-loadbalance-mode", + type=str, + required=True, + choices=[ + "normal", "no_loadbalance", "loadbalance_size", + "loadbalance_order" + ]) + parser.add_argument("--use-local-allgather", action="store_true") + parser.add_argument("--disable-tqdm", action="store_true") + args = parser.parse_args() + + if args.suite == "1-to-m": + case = suite.perf_1_to_m_suite[(args.n_nodes, args.gpu_per_node)] + else: + case = suite.perf_n_to_m_suite[args.case] + + result = benchmark_one_case_internal( + case.src_mesh_shape, case.dst_mesh_shape, case.src_sharding_spec, + case.dst_sharding_spec, case.tensor_shape, args.resharding_mode, + args.use_local_allgather, args.resharding_loadbalance_mode) + print(result) + +# python benchmark_cross_mesh_resharding.py --case case1 --resharding-mode broadcast --resharding-loadbalance-mode normal diff --git a/benchmark/alpa/resharding/suite.py b/benchmark/alpa/resharding/suite.py new file mode 100644 index 000000000..2821ea9e3 --- /dev/null +++ b/benchmark/alpa/resharding/suite.py @@ -0,0 +1,183 @@ +"""Benchmark suites for cross mesh resharding microbenchmarks.""" +from collections import namedtuple +from jax.interpreters.pxla import (Chunked, NoSharding, Replicated, ShardedAxis, + ShardingSpec) + +BenchmarkCase = namedtuple("BenchmarkCase", [ + "src_mesh_shape", "dst_mesh_shape", "tensor_shape", "src_sharding_spec", + "dst_sharding_spec" +]) + +perf_n_to_m_suite = { 
+ "case1": + BenchmarkCase( + (2, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec([Chunked( + [2]), NoSharding(), NoSharding()], + [ShardedAxis(0), Replicated(4)]), + ShardingSpec([Chunked( + [2]), NoSharding(), NoSharding()], + [ShardedAxis(0), Replicated(4)]), + ), + "case2": + BenchmarkCase( + (2, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec( + [NoSharding(), NoSharding(), + NoSharding()], [Replicated(8)]), + ShardingSpec([Chunked( + [2]), NoSharding(), NoSharding()], + [ShardedAxis(0), Replicated(4)]), + ), + "case3": + BenchmarkCase( + (2, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec( + [NoSharding(), Chunked([2]), + NoSharding()], [ShardedAxis(0), Replicated(4)]), + ShardingSpec([Chunked( + [2]), NoSharding(), NoSharding()], + [ShardedAxis(0), Replicated(4)]), + ), + "case4": + BenchmarkCase( + (2, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec( + [NoSharding(), Chunked([8]), + NoSharding()], [ShardedAxis(0)]), + ShardingSpec([Chunked( + [8]), NoSharding(), NoSharding()], [ShardedAxis(0)]), + ), + "case5": + BenchmarkCase( + (2, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec([Chunked( + [4]), NoSharding(), NoSharding()], + [Replicated(2), ShardedAxis(0)]), + ShardingSpec([Chunked( + [2]), NoSharding(), NoSharding()], + [ShardedAxis(0), Replicated(4)]), + ), + "case6": + BenchmarkCase( + (2, 4), + (3, 4), + # (1024*3//8, 1024, 170), + (1024 * 3, 1024, 170), + ShardingSpec([Chunked( + [2]), NoSharding(), NoSharding()], + [ShardedAxis(0), Replicated(4)]), + ShardingSpec([Chunked( + [3]), NoSharding(), NoSharding()], + [ShardedAxis(0), Replicated(4)]), + ), + "case7": + BenchmarkCase( + (1, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec([Chunked( + [4]), NoSharding(), NoSharding()], [ShardedAxis(0)]), + ShardingSpec( + [NoSharding(), NoSharding(), + NoSharding()], [Replicated(4)]), + ), + "case8": + BenchmarkCase( + (1, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec([Chunked( + [4]), NoSharding(), NoSharding()], [ShardedAxis(0)]), + ShardingSpec( + [NoSharding(), NoSharding(), + NoSharding()], [Replicated(4)]), + ), + "case9": + BenchmarkCase( + (2, 4), + (2, 4), + # (1024 // 8, 1024, 512), + (1024, 1024, 512), + ShardingSpec( + [NoSharding(), Chunked([2]), + NoSharding()], [ShardedAxis(0), Replicated(4)]), + ShardingSpec( + [NoSharding(), NoSharding(), + Chunked([2])], [ShardedAxis(0), Replicated(4)]), + ), +} + +resharding_n_to_m_configs = [ + { + "resharding_mode": "send_recv", + "resharding_loadbalance_mode": "normal", + "use_local_allgather": False + }, + { + "resharding_mode": "send_recv", + "resharding_loadbalance_mode": "normal", + "use_local_allgather": True + }, + { + "resharding_mode": "broadcast", + "resharding_loadbalance_mode": "no_loadbalance", + "use_local_allgather": False + }, + { + "resharding_mode": "broadcast", + "resharding_loadbalance_mode": "loadbalance_size", + "use_local_allgather": False + }, + { + "resharding_mode": "broadcast", + "resharding_loadbalance_mode": "loadbalance_order", + "use_local_allgather": False + }, +] + +perf_1_to_m_suite = {(n_node, gpu_per_node): BenchmarkCase( + (1, 1), + (n_node, gpu_per_node), + (1 << 28,), + ShardingSpec([NoSharding()], [Replicated(1)]), + ShardingSpec([NoSharding()], [Replicated(n_node * gpu_per_node)]), +) for n_node, gpu_per_node in [(1, 1), (1, 2), (1, 3), (1, 4), (2, + 2), (3, + 2), (4, 2)] + } + 
+resharding_1_to_m_configs = [ + { + "resharding_mode": "send_recv", + "resharding_loadbalance_mode": "normal", + "use_local_allgather": False + }, + { + "resharding_mode": "send_recv", + "resharding_loadbalance_mode": "normal", + "use_local_allgather": True + }, + { + "resharding_mode": "broadcast", + "resharding_loadbalance_mode": "normal", + "use_local_allgather": False + }, +] diff --git a/benchmark/alpa/suite_inference_moe.py b/benchmark/alpa/suite_inference_moe.py index 2c3437ef7..12449607f 100644 --- a/benchmark/alpa/suite_inference_moe.py +++ b/benchmark/alpa/suite_inference_moe.py @@ -35,12 +35,13 @@ def get_config(model_config, "uniform", parallel_args) profile_suite[num_gpus].append(case) + ## generate inference profiling results get_config(moe_specs["1.3B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], - [1, 2, 4, 8, 16]) + [1, 2, 4, 8, 16]) get_config(moe_specs["2.4B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], - [1, 2, 4, 8, 16]) + [1, 2, 4, 8, 16]) get_config(moe_specs["7.1B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], - [1, 2, 4, 8, 16]) + [1, 2, 4, 8, 16]) get_config(moe_specs["10B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], - [1, 2, 4, 8, 16]) + [1, 2, 4, 8, 16])
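For reference, the new flags added in `alpa/global_env.py` by this patch can also be set programmatically before initializing Alpa. The sketch below only shows the configuration step, assuming a Ray cluster is already running (as in the README above), and omits the actual pipeline-parallel workload.

```python
# Sketch: enable broadcast-based resharding with load-balanced ordering.
# resharding_loadbalance_mode and loadbalance_order_algo are introduced in
# alpa/global_env.py by this patch; resharding_mode already exists.
from alpa import init
from alpa.global_env import global_config

global_config.resharding_mode = "broadcast"  # or "send_recv"
global_config.resharding_loadbalance_mode = "loadbalance_order"
# Only consulted when resharding_loadbalance_mode == "loadbalance_order".
global_config.loadbalance_order_algo = "greedy"  # or "search"

init(cluster="ray")
```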