
Commit f9026cc

youkaichao and JC1DA authored and committed
[core][distributed] add stateless_init_process_group (vllm-project#10072)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Loc Huynh <jc1da.3011@gmail.com>
1 parent acd46ed commit f9026cc

File tree

3 files changed: +147 -3 lines changed

.buildkite/test-pipeline.yaml (+1 -1)

@@ -120,6 +120,7 @@ steps:
   - tests/spec_decode/e2e/test_integration_dist_tp4
   - tests/compile
   commands:
+  - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
@@ -431,7 +432,6 @@ steps:
   - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
-  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
 
 - label: Multi-step Tests (4 GPUs) # 36min
   working_dir: "/vllm-workspace/tests"

tests/distributed/test_utils.py (+73 -2)

@@ -1,9 +1,15 @@
+import pytest
 import ray
+import torch
+import torch.distributed as dist
 
 import vllm.envs as envs
+from vllm.distributed.utils import stateless_init_process_group
 from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)
 
+from ..utils import multi_gpu_test
+
 
 @ray.remote
 class _CUDADeviceCountStatelessTestActor:
@@ -24,10 +30,75 @@ def test_cuda_device_count_stateless():
     CUDA_VISIBLE_DEVICES is changed."""
     actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
         num_gpus=2).remote()
-    assert sorted(ray.get(
-        actor.get_cuda_visible_devices.remote()).split(",")) == ["0", "1"]
+    assert len(
+        sorted(ray.get(
+            actor.get_cuda_visible_devices.remote()).split(","))) == 2
     assert ray.get(actor.get_count.remote()) == 2
     ray.get(actor.set_cuda_visible_devices.remote("0"))
     assert ray.get(actor.get_count.remote()) == 1
     ray.get(actor.set_cuda_visible_devices.remote(""))
     assert ray.get(actor.get_count.remote()) == 0
+
+
+def cpu_worker(rank, WORLD_SIZE):
+    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE,
+                                       backend="gloo")
+    if rank <= 2:
+        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
+                                           rank=rank,
+                                           world_size=3,
+                                           backend="gloo")
+    data = torch.tensor([rank])
+    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    if rank <= 2:
+        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+    item = data[0].item()
+    print(f"rank: {rank}, item: {item}")
+    if rank == 3:
+        assert item == 6
+    else:
+        assert item == 18
+
+
+def gpu_worker(rank, WORLD_SIZE):
+    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE,
+                                       backend="nccl")
+    if rank <= 2:
+        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
+                                           rank=rank,
+                                           world_size=3,
+                                           backend="nccl")
+    torch.cuda.set_device(rank)
+    data = torch.tensor([rank]).cuda()
+    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    if rank <= 2:
+        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+    item = data[0].item()
+    print(f"rank: {rank}, item: {item}")
+    if rank == 3:
+        assert item == 6
+    else:
+        assert item == 18
+
+
+@multi_gpu_test(num_gpus=4)
+@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
+def test_stateless_init_process_group(worker):
+    WORLD_SIZE = 4
+    from multiprocessing import get_context
+    ctx = get_context("fork")
+    processes = []
+    for i in range(WORLD_SIZE):
+        rank = i
+        processes.append(ctx.Process(target=worker, args=(rank, WORLD_SIZE)))
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+    for p in processes:
+        assert not p.exitcode
+    print("All processes finished.")

vllm/distributed/utils.py (+73 -0)

@@ -5,6 +5,11 @@
 from typing import Sequence, Tuple
 
 import torch
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                _get_default_timeout,
+                                                is_nccl_available)
+from torch.distributed.rendezvous import rendezvous
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -84,3 +89,71 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
         end_layer = num_hidden_layers
 
     return (start_layer, end_layer)
+
+
+def stateless_init_process_group(init_method: str, rank: int, world_size: int,
+                                 backend: str) -> ProcessGroup:
+    """A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state.
+
+    If process A and process B call `torch.distributed.init_process_group`
+    to form a group, and we then want to form another group with processes A, B,
+    C, and D, that is not possible in PyTorch, because A and B have already
+    formed a group, and C and D cannot join that group. This function is a
+    workaround for this issue.
+
+    `torch.distributed.init_process_group` is a global call, while this function
+    is a stateless call. It returns a `ProcessGroup` object that can be used
+    for collective communication. With this function, processes A and B can
+    call `stateless_init_process_group` to form a group, and then processes A,
+    B, C, and D can call `stateless_init_process_group` to form another group.
+    """  # noqa
+
+    backend = Backend(backend)  # it is basically a string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
+
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        pg_options,
+    )
+
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
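
For context, a minimal usage sketch of the new helper (an editorial example, not part of the commit; the TCP endpoint and rank setup are arbitrary placeholders) could look like this, launched once per rank with RANK set to 0 or 1:

# Usage sketch for stateless_init_process_group (gloo backend, 2 ranks).
import os

import torch
import torch.distributed as dist

from vllm.distributed.utils import stateless_init_process_group

rank = int(os.environ["RANK"])  # assumed to be set by the launcher
pg = stateless_init_process_group(init_method="tcp://127.0.0.1:29600",
                                  rank=rank,
                                  world_size=2,
                                  backend="gloo")

# The returned ProcessGroup is passed explicitly to collectives; the global
# default group (torch.distributed.init_process_group) is never created.
data = torch.tensor([rank])
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg)
print(f"rank {rank}: reduced value = {data.item()}")  # 0 + 1 = 1 on both ranks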
