|
 from typing import Sequence, Tuple

 import torch
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                 _get_default_timeout,
+                                                 is_nccl_available)
+from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -84,3 +89,71 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
             end_layer = num_hidden_layers

     return (start_layer, end_layer)
+
+
+def stateless_init_process_group(init_method: str, rank: int, world_size: int,
+                                 backend: str) -> ProcessGroup:
+    """A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state.
+
+    If process A and process B call `torch.distributed.init_process_group` to
+    form a group, and we then want to form another group with processes A, B,
+    C, and D, that is not possible in PyTorch, because A and B have already
+    formed a group that C and D cannot join. This function is a workaround
+    for that limitation.
+
+    `torch.distributed.init_process_group` mutates global state, while this
+    function is stateless. It returns a `ProcessGroup` object that can be used
+    for collective communication. With this function, processes A and B can
+    call `stateless_init_process_group` to form one group, and then A, B, C,
+    and D can call `stateless_init_process_group` to form another group.
+    """  # noqa
+
+    backend = Backend(backend)  # it is basically a string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
+
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        pg_options,
+    )
+
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
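
Below is a minimal usage sketch, not part of the diff above, showing the pattern the docstring describes: ranks 0 and 1 ("A" and "B") first form a small group, then all four ranks form a larger one, which a single `torch.distributed.init_process_group` call would not allow. It assumes the function is importable from `vllm.distributed.utils` (the module this diff appears to touch), that the `gloo` backend is available, and that the two localhost ports are free; the ports, tensor sizes, and worker layout are illustrative assumptions, not values from this commit.

# Hypothetical usage sketch; module path, ports, and tensor values are
# assumptions for illustration, not taken from the commit.
import torch
import torch.multiprocessing as mp

from vllm.distributed.utils import stateless_init_process_group


def worker(rank: int) -> None:
    # Step 1: only ranks 0 and 1 ("A" and "B") form a 2-way gloo group.
    if rank < 2:
        small_pg = stateless_init_process_group(
            init_method="tcp://127.0.0.1:29501",  # assumed free port
            rank=rank,
            world_size=2,
            backend="gloo",
        )
        t = torch.ones(4)
        # allreduce() on the raw ProcessGroup returns a Work handle;
        # exact binding details may vary across PyTorch versions.
        small_pg.allreduce(t).wait()
        assert torch.allclose(t, torch.full((4, ), 2.0))

    # Step 2: because no global state was touched, all four ranks
    # ("A" through "D") can still form a second, larger group.
    big_pg = stateless_init_process_group(
        init_method="tcp://127.0.0.1:29502",  # assumed free port
        rank=rank,
        world_size=4,
        backend="gloo",
    )
    t = torch.ones(4)
    big_pg.allreduce(t).wait()
    assert torch.allclose(t, torch.full((4, ), 4.0))


if __name__ == "__main__":
    mp.spawn(worker, nprocs=4)

Each group rendezvous on its own `init_method` address, so the two stores stay separate; inside the function, wrapping the store in a `PrefixStore` keyed by `init_method` additionally guards against key collisions when a store is shared with other systems.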