[core][distributed] add stateless_init_process_group #10072

Merged · 8 commits · Nov 7, 2024
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -120,6 +120,7 @@ steps:
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
commands:
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
@@ -431,7 +432,6 @@ steps:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
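The pipeline change moves `distributed/test_utils.py` out of the job that pinned it to two GPUs (`CUDA_VISIBLE_DEVICES=0,1`) and into the 4-GPU job, since the new `test_stateless_init_process_group` test below is decorated with `@multi_gpu_test(num_gpus=4)`.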
75 changes: 73 additions & 2 deletions tests/distributed/test_utils.py
@@ -1,9 +1,15 @@
import pytest
import ray
import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed.utils import stateless_init_process_group
from vllm.utils import (cuda_device_count_stateless,
                        update_environment_variables)

from ..utils import multi_gpu_test


@ray.remote
class _CUDADeviceCountStatelessTestActor:
@@ -24,10 +30,75 @@ def test_cuda_device_count_stateless():
    CUDA_VISIBLE_DEVICES is changed."""
    actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
        num_gpus=2).remote()
    assert sorted(ray.get(
        actor.get_cuda_visible_devices.remote()).split(",")) == ["0", "1"]
    assert len(
        sorted(ray.get(
            actor.get_cuda_visible_devices.remote()).split(","))) == 2
    assert ray.get(actor.get_count.remote()) == 2
    ray.get(actor.set_cuda_visible_devices.remote("0"))
    assert ray.get(actor.get_count.remote()) == 1
    ray.get(actor.set_cuda_visible_devices.remote(""))
    assert ray.get(actor.get_count.remote()) == 0


def cpu_worker(rank, WORLD_SIZE):
    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
                                       rank=rank,
                                       world_size=WORLD_SIZE,
                                       backend="gloo")
    if rank <= 2:
        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
                                           rank=rank,
                                           world_size=3,
                                           backend="gloo")
    data = torch.tensor([rank])
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
    if rank <= 2:
        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
    item = data[0].item()
    print(f"rank: {rank}, item: {item}")
    if rank == 3:
        assert item == 6
    else:
        assert item == 18


def gpu_worker(rank, WORLD_SIZE):
    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
                                       rank=rank,
                                       world_size=WORLD_SIZE,
                                       backend="nccl")
    if rank <= 2:
        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
                                           rank=rank,
                                           world_size=3,
                                           backend="nccl")
    torch.cuda.set_device(rank)
    data = torch.tensor([rank]).cuda()
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
    if rank <= 2:
        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
    item = data[0].item()
    print(f"rank: {rank}, item: {item}")
    if rank == 3:
        assert item == 6
    else:
        assert item == 18


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
def test_stateless_init_process_group(worker):
    WORLD_SIZE = 4
    from multiprocessing import get_context
    ctx = get_context("fork")
    processes = []
    for i in range(WORLD_SIZE):
        rank = i
        processes.append(ctx.Process(target=worker, args=(rank, WORLD_SIZE)))
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    for p in processes:
        assert not p.exitcode
    print("All processes finished.")
73 changes: 73 additions & 0 deletions vllm/distributed/utils.py
@@ -5,6 +5,11 @@
from typing import Sequence, Tuple

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
                                                is_nccl_available)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
@@ -84,3 +89,71 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
end_layer = num_hidden_layers

return (start_layer, end_layer)


def stateless_init_process_group(init_method: str, rank: int, world_size: int,
                                 backend: str) -> ProcessGroup:
    """A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state.

    If processes A and B call `torch.distributed.init_process_group` to form a
    group, and we then want to form another group consisting of processes A, B,
    C, and D, that is not possible in PyTorch: A and B have already formed a
    group, and C and D cannot join it. This function works around that
    limitation.

    `torch.distributed.init_process_group` is a global call, while this
    function is stateless. It returns a `ProcessGroup` object that can be used
    for collective communication. With this function, processes A and B can
    call `stateless_init_process_group` to form one group, and processes A, B,
    C, and D can then call `stateless_init_process_group` to form another
    group.
    """  # noqa

    backend = Backend(backend)  # it is basically a string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)

    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
        pg_options,
    )

    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")

    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg
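
For reference, here is a minimal usage sketch (not part of this diff) showing two processes forming a group through `stateless_init_process_group` and running an all-reduce over it, without ever touching the default group that `torch.distributed.init_process_group` would create. The worker name `_example_worker`, the address `tcp://127.0.0.1:29600`, and the fork start method are illustrative choices mirroring the test above, not anything mandated by the API.

import torch
import torch.distributed as dist
from multiprocessing import get_context

from vllm.distributed.utils import stateless_init_process_group


def _example_worker(rank: int, world_size: int) -> None:
    # Each process gets its own ProcessGroup handle; no global/default group
    # is created, so torch.distributed.is_initialized() should stay False.
    pg = stateless_init_process_group(init_method="tcp://127.0.0.1:29600",
                                      rank=rank,
                                      world_size=world_size,
                                      backend="gloo")
    data = torch.tensor([rank + 1])
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg)
    # With world_size == 2, every rank ends up with 1 + 2 == 3.
    assert data[0].item() == 3
    assert not dist.is_initialized()


if __name__ == "__main__":
    ctx = get_context("fork")
    procs = [
        ctx.Process(target=_example_worker, args=(r, 2)) for r in range(2)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0

Because each call performs its own rendezvous and registers its own backend on the ProcessGroup it returns, overlapping groups, such as the 3-rank pg2 inside the 4-rank pg1 in the test above, can coexist within the same processes.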