[core][distributed] add stateless process group (vllm-project#10216)
Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored Nov 11, 2024
1 parent 945180e commit d1eacd2
Showing 3 changed files with 217 additions and 112 deletions.
79 changes: 52 additions & 27 deletions tests/distributed/test_utils.py
@@ -1,10 +1,10 @@
 import pytest
 import ray
 import torch
 import torch.distributed as dist
 
 import vllm.envs as envs
-from vllm.distributed.utils import stateless_init_process_group
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)
@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
 
 
 def cpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
-                                       rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="gloo")
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
-                                           rank=rank,
-                                           world_size=3,
-                                           backend="gloo")
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
+                                           rank=rank,
+                                           world_size=3)
     data = torch.tensor([rank])
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    data = pg1.broadcast_obj(data, src=2)
+    assert data.item() == 2
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
-    item = data[0].item()
-    print(f"rank: {rank}, item: {item}")
-    if rank == 3:
-        assert item == 6
-    else:
-        assert item == 18
+        data = torch.tensor([rank + 1])
+        data = pg2.broadcast_obj(data, src=2)
+        assert data.item() == 3
+        pg2.barrier()
+    pg1.barrier()


 def gpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
-                                       rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="nccl")
+    torch.cuda.set_device(rank)
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    pynccl1 = PyNcclCommunicator(pg1, device=rank)
+    pynccl1.disabled = False
     if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
-                                           rank=rank,
-                                           world_size=3,
-                                           backend="nccl")
-    torch.cuda.set_device(rank)
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
+                                           rank=rank,
+                                           world_size=3)
+        pynccl2 = PyNcclCommunicator(pg2, device=rank)
+        pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    pynccl1.all_reduce(data)
+    pg1.barrier()
+    torch.cuda.synchronize()
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+        pynccl2.all_reduce(data)
+        pg2.barrier()
+        torch.cuda.synchronize()
     item = data[0].item()
     print(f"rank: {rank}, item: {item}")
     if rank == 3:
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
         assert item == 18
 
 
+def broadcast_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29504",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    if rank == 2:
+        pg1.broadcast_obj("secret", src=2)
+    else:
+        obj = pg1.broadcast_obj(None, src=2)
+        assert obj == "secret"
+    pg1.barrier()
+
+
+def allgather_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29505",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    data = pg1.all_gather_obj(rank)
+    assert data == list(range(WORLD_SIZE))
+    pg1.barrier()
+
+
 @multi_gpu_test(num_gpus=4)
-@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
-def test_stateless_init_process_group(worker):
+@pytest.mark.parametrize(
+    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
+def test_stateless_process_group(worker):
     WORLD_SIZE = 4
     from multiprocessing import get_context
     ctx = get_context("fork")
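
The object-level collectives exercised by these tests (broadcast_obj, all_gather_obj, barrier) rendezvous over the given tcp:// address and need no torch.distributed initialization. A minimal standalone sketch of the same pattern (hypothetical, not part of the diff; port 29600 and the two-rank world size are arbitrary choices):

    # Hypothetical usage sketch (not from this commit's diff).
    # Two forked processes coordinate via a StatelessProcessGroup
    # without ever calling torch.distributed.init_process_group.
    from multiprocessing import get_context

    from vllm.distributed.utils import StatelessProcessGroup


    def worker(rank: int, world_size: int):
        pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29600",
                                          rank=rank,
                                          world_size=world_size)
        # the src rank passes the object; every rank (src included) gets it back
        msg = pg.broadcast_obj("hello" if rank == 0 else None, src=0)
        assert msg == "hello"
        # each rank contributes one object; the result is ordered by rank
        assert pg.all_gather_obj(rank) == list(range(world_size))
        pg.barrier()


    if __name__ == "__main__":
        ctx = get_context("fork")
        procs = [ctx.Process(target=worker, args=(r, 2)) for r in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
            assert p.exitcode == 0
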
38 changes: 24 additions & 14 deletions vllm/distributed/device_communicators/pynccl.py
@@ -9,6 +9,7 @@
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -18,7 +19,7 @@ class PyNcclCommunicator:
 
     def __init__(
         self,
-        group: ProcessGroup,
+        group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
     ):
@@ -33,13 +34,18 @@ def __init__(
         It is the caller's responsibility to make sure each communicator
         is bound to a unique device.
         """
-        assert dist.is_initialized()
-        assert dist.get_backend(group) != dist.Backend.NCCL, (
-            "PyNcclCommunicator should be attached to a non-NCCL group.")
+        if not isinstance(group, StatelessProcessGroup):
+            assert dist.is_initialized()
+            assert dist.get_backend(group) != dist.Backend.NCCL, (
+                "PyNcclCommunicator should be attached to a non-NCCL group.")
+            # note: this rank is the rank in the group
+            self.rank = dist.get_rank(group)
+            self.world_size = dist.get_world_size(group)
+        else:
+            self.rank = group.rank
+            self.world_size = group.world_size
+
         self.group = group
-        # note: this rank is the rank in the group
-        self.rank = dist.get_rank(group)
-        self.world_size = dist.get_world_size(group)
 
         # if world_size == 1, no need to create communicator
         if self.world_size == 1:
@@ -68,13 +74,17 @@ def __init__(
         else:
             # construct an empty unique id
             self.unique_id = ncclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal))
-        ranks = dist.get_process_group_ranks(group)
-        # arg `src` in `broadcast` is the global rank
-        dist.broadcast(tensor, src=ranks[0], group=group)
-        byte_list = tensor.tolist()
-        for i, byte in enumerate(byte_list):
-            self.unique_id.internal[i] = byte
+
+        if not isinstance(group, StatelessProcessGroup):
+            tensor = torch.ByteTensor(list(self.unique_id.internal))
+            ranks = dist.get_process_group_ranks(group)
+            # arg `src` in `broadcast` is the global rank
+            dist.broadcast(tensor, src=ranks[0], group=group)
+            byte_list = tensor.tolist()
+            for i, byte in enumerate(byte_list):
+                self.unique_id.internal[i] = byte
+        else:
+            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
         if isinstance(device, int):
             device = torch.device(f"cuda:{device}")
         elif isinstance(device, str):
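
With the pynccl.py changes above, a PyNcclCommunicator can be driven entirely by a StatelessProcessGroup, which is exactly what the reworked gpu_worker exercises. A condensed sketch of that pattern (hypothetical port 29700; assumes one GPU per rank and is run once per rank, like the workers above):

    # Condensed from gpu_worker above; the port and tensor shape are
    # illustrative. NCCL all-reduce with no torch.distributed state.
    import torch

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup


    def nccl_worker(rank: int, world_size: int):
        torch.cuda.set_device(rank)  # select this rank's GPU before NCCL init
        pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29700",
                                          rank=rank,
                                          world_size=world_size)
        comm = PyNcclCommunicator(pg, device=rank)
        comm.disabled = False  # the tests enable the communicator explicitly
        data = torch.ones(16, device=f"cuda:{rank}")
        comm.all_reduce(data)  # in-place sum, as the test's data check relies on
        pg.barrier()
        torch.cuda.synchronize()
        assert torch.all(data == world_size)
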
212 changes: 141 additions & 71 deletions vllm/distributed/utils.py (diff not loaded on this page; the counts are inferred from the 217/112 totals above, and per the imports in the two diffs shown this is the file that defines StatelessProcessGroup.)
