This repository was archived by the owner on Oct 19, 2024. It is now read-only.

[FEATURE] load balance in cross-mesh resharding #798

Merged 16 commits on Dec 17, 2022
12 changes: 8 additions & 4 deletions alpa/device_mesh.py
@@ -599,10 +599,14 @@ def get_live_buffer_uuids(self):
return list(self.buffers.keys())

##### Other Functions #####
def sync(self):
def sync(self, sync_all_devices=False):
# We sync one device instead of all for smaller runtime overhead.
# This is correct because of SPMD.
self.local_devices[0].synchronize_all_activity()
if sync_all_devices:
for device in self.local_devices:
device.synchronize_all_activity()
else:
self.local_devices[0].synchronize_all_activity()

def sync_all(self):
for device in self.local_devices:
@@ -1396,8 +1400,8 @@ def reset_memory_stats(self):
ray.get(worker.reset_memory_stats.remote())

##### Other Functions #####
def sync_workers(self):
ray.get([w.sync.remote() for w in self.workers])
def sync_workers(self, sync_all_devices=False):
ray.get([w.sync.remote(sync_all_devices) for w in self.workers])

def sync_move_workers(self):
ray.get([w.sync_move_worker.remote() for w in self.workers])
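As a usage sketch (a hypothetical call site, not part of this diff): when timing cross-mesh resharding it can be safer to drain every device with the new flag rather than rely on the SPMD shortcut of syncing only `local_devices[0]`. Here `mesh` is assumed to be a distributed physical device mesh exposing `sync_workers`, and `resharding_task` is a placeholder for the communication being measured.

```
import time

mesh.sync_workers(sync_all_devices=True)  # drain all devices before timing
start = time.time()
resharding_task()                         # hypothetical work under test
mesh.sync_workers(sync_all_devices=True)  # ensure every transfer has finished
elapsed = time.time() - start
print(f"resharding took {elapsed:.3f} s")
```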
5 changes: 5 additions & 0 deletions alpa/global_env.py
@@ -65,6 +65,11 @@ def __init__(self):
# "xla_extension"}
self.nccl_mode = "cupy"
self.enable_overlapping = False
# Cross mesh resharding load balancing mode.
# Possible choices: {"normal", "no_loadbalance",
# "loadbalance_size", "loadbalance_order"}
self.resharding_loadbalance_mode = "normal"
self.loadbalance_order_algo = "greedy"

########## Options of benchmark ##########
# If true, the system is allowed to use dummy values during
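A minimal sketch of how these new options might be set from user code, assuming the `global_config` object defined in `alpa/global_env.py` is importable from the `alpa` package as usual:

```
from alpa import global_config

# Choose how cross-mesh resharding balances communication across devices.
# Modes added above: "normal", "no_loadbalance", "loadbalance_size",
# "loadbalance_order".
global_config.resharding_loadbalance_mode = "loadbalance_size"

# Ordering algorithm consulted when the mode is "loadbalance_order".
global_config.loadbalance_order_algo = "greedy"
```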
938 changes: 830 additions & 108 deletions alpa/pipeline_parallel/cross_mesh_resharding.py

Large diffs are not rendered by default.
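The rendered diff is omitted, but the modes above point at the core idea: spread resharding send/recv work across candidate devices instead of always routing it through the first one. A simplified, illustrative sketch of size-based greedy balancing (not the code in this PR; all names are hypothetical):

```
import heapq

def assign_tasks_balanced(task_sizes, senders):
    """Greedily assign each task (largest first) to the least-loaded sender."""
    heap = [(0, s) for s in senders]  # (accumulated bytes, sender id)
    heapq.heapify(heap)
    assignment = {s: [] for s in senders}
    for idx in sorted(range(len(task_sizes)), key=lambda i: -task_sizes[i]):
        load, sender = heapq.heappop(heap)
        assignment[sender].append(idx)
        heapq.heappush(heap, (load + task_sizes[idx], sender))
    return assignment

# Example: 4 transfers of different sizes spread over 2 sender GPUs.
print(assign_tasks_balanced([8, 1, 4, 2], senders=[0, 1]))
# -> {0: [0], 1: [2, 3, 1]}  (loads: 8 vs. 7)
```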

17 changes: 9 additions & 8 deletions benchmark/alpa/benchmark_one_case_moe_inference.py
@@ -18,12 +18,13 @@

def create_infer_params_aval(rngkey, model, batch):
params = jax.eval_shape(model.init, rngkey, batch["input_ids"],
batch["attention_mask"],
batch["token_type_ids"], batch["position_ids"])
batch["attention_mask"], batch["token_type_ids"],
batch["position_ids"])
params = jax.eval_shape(
lambda p: jax.tree_util.tree_map(
lambda x: jnp.asarray(x, dtype=jnp.float16), p), params)
return params
return params


def get_infer_step(parallel_method, model):

@@ -41,7 +42,7 @@ def infer_step(params, batch, rng_key):
loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
loss = (label_mask * loss).sum() / label_mask.sum()
return loss

return parallelize(infer_step, method=parallel_method, donate_argnums=())


@@ -60,8 +61,8 @@ def prepare_moe_inference_input_and_model(benchmark_case,
if correct_expert_group_size:
rang_factor = 1
expected_expert_group_size = min(
expert_group_size,
batch_size * seq_len // benchmark_case.num_micro_batches // 1 // rang_factor)
expert_group_size, batch_size * seq_len //
benchmark_case.num_micro_batches // 1 // rang_factor)
if expected_expert_group_size != expert_group_size:
print("- Expected expert group size should be {}, "
"but got {}. Will reset it".format(expected_expert_group_size,
@@ -152,8 +153,8 @@ def benchmark_moe_inference_internal(benchmark_case,

infer_step = get_infer_step(method, model)

(latencies, max_mem_allocated, compilation_times,
executable, per_stage_weight_mem,
(latencies, max_mem_allocated, compilation_times, executable,
per_stage_weight_mem,
per_stage_peak_mem) = compile_and_benchmark_pipeshard_inference_executable(
benchmark_case.parallel_mode,
niter,
1 change: 0 additions & 1 deletion benchmark/alpa/gen_serving_database.py
@@ -8,7 +8,6 @@

from alpa_serve.profiling import ProfilingDatabase


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default="inference_prof_res.tsv")
70 changes: 70 additions & 0 deletions benchmark/alpa/resharding/README.md
@@ -0,0 +1,70 @@
# Benchmark
This folder contains benchmarking code for cross mesh resharding, corresponding to the experiment section in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322).

To make the benchmark feasible in a short amount of time, this documentation provides instructions for benchmarking on an AWS p3.8xlarge cluster. You can use them to quickly run cross mesh resharding with Alpa and collect performance statistics. The statistics may differ from those in our paper if your cluster is not an AWS p3.8xlarge cluster.
There are two types of experiments for benchmarking:
- Single device to multiple devices microbenchmark: corresponds to section 5.1.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322).
- Multiple devices to multiple devices microbenchmark: corresponds to sections 5.1.2 and 5.3.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322).

## Benchmark Steps

### Cluster Preparation

Prepare 5 AWS p3.8xlarge instances and put them in the same Placement Group.

### Start a Ray Cluster
Alpa uses the distributed framework Ray to manage the cluster and distributed workers.
Here, we provide instructions for manually launching a Ray cluster.
You can also refer to the Ray [documentation](https://docs.ray.io/en/latest/cluster/quickstart.html#) for other ways to launch and manage Ray clusters.

1. Pick one node as the head node and run the command below on it
```
ray start --head
```
2. On each of the other 4 nodes, connect to the head node by following the instructions printed by the previous command.
```
# The command should look like this, but with the IP address and password printed by the previous command.
ray start --address='172.31.31.37:6379' --redis-password='5241590000000000'
```

You can check the cluster status by
```
ray status
```
You should be able to see the number of CPUs and GPUs available in your cluster; 20 GPUs in total are needed to proceed.
All nodes should have Alpa installed.

### Single device to multiple devices microbenchmark
Run all benchmark tests with all GPUs in your cluster.
```
python3 benchmark.py --suite 1-to-m
```
The result will be saved in `tmp/1_to_m_result.json`. In this set of experiments, the sender mesh has only 1 GPU and we vary the number of GPUs in the receiver mesh. In the first half of the benchmark tests, the receiver mesh has 1 node and the number of GPUs in that node varies from 1 to 4. In the second half, the number of GPUs per node is fixed at 2 while the number of nodes in the receiver mesh grows from 1 to 4. For more details, please refer to `perf_1_to_m_suite` in `suite.py`.

If you only want to run one test case,
```
python3 benchmark_cross_mesh_resharding.py --suite 1-to-m --n-nodes 1 --gpu-per-node 4 --resharding-mode send_recv --resharding-loadbalance-mode normal
```
Here, the destination mesh is (1, 4) as an example; you can also choose other cases.
You can use the `--resharding-mode`, `--resharding-loadbalance-mode`, and `--use-local-allgather` flags
to specify the configuration for cross mesh resharding.
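For example, a small shell loop (not part of the benchmark suite; it assumes the script accepts the same mode names as `global_config.resharding_loadbalance_mode`) can sweep the load-balance modes for the same destination mesh:

```
for mode in normal no_loadbalance loadbalance_size loadbalance_order; do
    python3 benchmark_cross_mesh_resharding.py --suite 1-to-m \
        --n-nodes 1 --gpu-per-node 4 \
        --resharding-mode send_recv \
        --resharding-loadbalance-mode "$mode"
done
```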

### Multiple devices to multiple devices microbenchmark
Similar to the previous subsection.
```
python3 benchmark.py --suite n-to-m
```
The result will be saved in `tmp/n_to_m_result.json`. In this set of experiments, we move to more complicated cases where both the sender mesh and the receiver mesh have multiple nodes. For more details, please refer to `perf_n_to_m_suite` in `suite.py`.

If you only want to run one test case,
```
python3 benchmark_cross_mesh_resharding.py --suite n-to-m --case case1 --resharding-mode send_recv --resharding-loadbalance-mode normal
```
Here, `case1` is used as an example; you can choose other cases by referring to `suite.py`. As above, you can use the same flags to
specify the configuration for cross mesh resharding.

## Result

Using the benchmark scripts above, you can compare the time spent under different resharding configurations
and check the conclusions of [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322) against
these statistics.
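As a small helper sketch for inspecting the saved statistics (the files are plain JSON lists written by `benchmark.py`; the exact record layout comes from `benchmark_one_case_internal` and may differ from what is shown here):

```
import json

with open("tmp/1_to_m_result.json") as f:
    results = json.load(f)

for record in results:
    print(record)
```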
101 changes: 101 additions & 0 deletions benchmark/alpa/resharding/benchmark.py
@@ -0,0 +1,101 @@
"""The entry point of intra-op + inter-op parallelism benchmark."""
import argparse
import json
import multiprocessing as mp
import os
import time

from benchmark_cross_mesh_resharding import benchmark_one_case_internal
import suite


def benchmark_and_write_to_namespace(result_namespace, *args, **kwargs):
result = benchmark_one_case_internal(*args, **kwargs)
result_namespace.result = result


def benchmark_one_case(*args, use_separate_process=False, **kwargs):
if not use_separate_process:
return benchmark_one_case_internal(*args, **kwargs)
ctx = mp.get_context("spawn")
manager = ctx.Manager()
result_namespace = manager.Namespace()
p = ctx.Process(target=benchmark_and_write_to_namespace,
args=(result_namespace, *args),
kwargs=kwargs)
p.start()
p.join()
if p.exitcode != 0:
return -1, -1, [-1], -1, None
return result_namespace.result


def benchmark_n_to_m_suite():
os.makedirs("tmp", exist_ok=True)

result_file = "tmp/n_to_m_result.json"
result = []

benchmark_cases = suite.perf_n_to_m_suite
resharding_config_list = suite.resharding_n_to_m_configs

# Run all cases
for case_name, benchmark_case in benchmark_cases.items():
# Run one case
for config in resharding_config_list:
print("Working on {}: {}, config: {}".format(
case_name, str(benchmark_case), str(config)))
one_result = benchmark_one_case(
benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape,
benchmark_case.src_sharding_spec,
benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape,
config["resharding_mode"], config["use_local_allgather"],
config["resharding_loadbalance_mode"])

print(one_result)
result.append(one_result)
json.dump(result, open(result_file, "w"), indent=4)

time.sleep(0.1) # for ctrl+c to work


def benchmark_1_to_m_suite():
os.makedirs("tmp", exist_ok=True)

result_file = "tmp/1_to_m_result.json"
result = []

benchmark_cases = suite.perf_1_to_m_suite
resharding_config_list = suite.resharding_1_to_m_configs

# Run all cases
for case_name, benchmark_case in benchmark_cases.items():
# Run one case
for config in resharding_config_list:
print("Working on {}: {}, config: {}".format(
case_name, str(benchmark_case), str(config)))
one_result = benchmark_one_case(
benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape,
benchmark_case.src_sharding_spec,
benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape,
config["resharding_mode"], config["use_local_allgather"],
config["resharding_loadbalance_mode"])
print(one_result)
result.append(one_result)
json.dump(result, open(result_file, "w"), indent=4)

time.sleep(0.1) # for ctrl+c to work


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--suite",
choices=["1-to-m", "n-to-m"],
type=str,
required=True)
args = parser.parse_args()

if args.suite == "1-to-m":
benchmark_1_to_m_suite()
else:
benchmark_n_to_m_suite()