Commit

[bugfix] respect distributed_executor_backend in world_size=1 (vllm-project#12934)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored and AoyuQC committed Feb 8, 2025
1 parent ffdc44c commit a545163
Showing 4 changed files with 53 additions and 32 deletions.
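For context, the fixed behavior is easy to see end to end: before this commit, a single-GPU run (TP=1, PP=1) that explicitly requested the Ray backend silently fell back to the uniprocess executor; afterwards the requested backend is honored. A minimal sketch mirroring the new test below (assumes Ray is installed; model name taken from the tests):

from vllm import EngineArgs, LLMEngine

# Explicitly request Ray even though world_size == 1.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    distributed_executor_backend="ray",
    enforce_eager=True,  # reduce startup time, as in the tests
)
engine = LLMEngine.from_engine_args(engine_args)
assert engine.model_executor.uses_ray  # respected after this commit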
21 changes: 20 additions & 1 deletion tests/engine/test_custom_executor.py

@@ -55,6 +55,7 @@ def test_custom_executor(model, tmp_path):
     engine_args = EngineArgs(
         model=model,
         distributed_executor_backend=CustomUniExecutor,
+        enforce_eager=True,  # reduce test time
     )
     engine = LLMEngine.from_engine_args(engine_args)
     sampling_params = SamplingParams(max_tokens=1)
@@ -75,7 +76,10 @@ def test_custom_executor_async(model, tmp_path):
     assert not os.path.exists(".marker")

     engine_args = AsyncEngineArgs(
-        model=model, distributed_executor_backend=CustomUniExecutorAsync)
+        model=model,
+        distributed_executor_backend=CustomUniExecutorAsync,
+        enforce_eager=True,  # reduce test time
+    )
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     sampling_params = SamplingParams(max_tokens=1)

@@ -89,3 +93,18 @@ async def t():
         assert os.path.exists(".marker")
     finally:
         os.chdir(cwd)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_respect_ray(model):
+    # even for TP=1 and PP=1,
+    # if users specify ray, we should use ray.
+    # users might do this if they want to manage the
+    # resources using ray.
+    engine_args = EngineArgs(
+        model=model,
+        distributed_executor_backend="ray",
+        enforce_eager=True,  # reduce test time
+    )
+    engine = LLMEngine.from_engine_args(engine_args)
+    assert engine.model_executor.uses_ray
3 changes: 3 additions & 0 deletions vllm/config.py

@@ -1401,6 +1401,9 @@ def __post_init__(self) -> None:
             logger.info("Defaulting to use %s for distributed inference",
                         backend)

+        if self.distributed_executor_backend is None and self.world_size == 1:
+            self.distributed_executor_backend = "uni"
+
         self._verify_args()

     @property
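The new default is observable directly on the resolved configuration. A rough sketch, assuming EngineArgs.create_engine_config() is used to build the VllmConfig:

from vllm.engine.arg_utils import EngineArgs

# With no backend specified and world_size == 1, __post_init__ now
# resolves the backend to "uni" instead of leaving it as None.
config = EngineArgs(model="facebook/opt-125m",
                    enforce_eager=True).create_engine_config()
assert config.parallel_config.distributed_executor_backend == "uni"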
44 changes: 22 additions & 22 deletions vllm/engine/llm_engine.py

@@ -434,6 +434,7 @@ def _initialize_kv_caches(self) -> None:
     @classmethod
     def _get_executor_cls(cls,
                           engine_config: VllmConfig) -> Type[ExecutorBase]:
+        # distributed_executor_backend must be set in VllmConfig.__post_init__
         distributed_executor_backend = (
             engine_config.parallel_config.distributed_executor_backend)
         # Initialize the cluster and specify the executor class.
@@ -443,30 +444,29 @@ def _get_executor_cls(cls,
                 "distributed_executor_backend must be a subclass of "
                 f"ExecutorBase. Got {distributed_executor_backend}.")
             executor_class = distributed_executor_backend
-        elif engine_config.parallel_config.world_size > 1:
-            if distributed_executor_backend == "ray":
-                from vllm.executor.ray_distributed_executor import (
-                    RayDistributedExecutor)
-                executor_class = RayDistributedExecutor
-            elif distributed_executor_backend == "mp":
-                from vllm.executor.mp_distributed_executor import (
-                    MultiprocessingDistributedExecutor)
-                assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
-                    "multiprocessing distributed executor backend does not "
-                    "support VLLM_USE_RAY_SPMD_WORKER=1")
-                executor_class = MultiprocessingDistributedExecutor
-            elif distributed_executor_backend == "uni":
-                # JAX-style, single-process, multi-device executor.
-                from vllm.executor.uniproc_executor import UniProcExecutor
-                executor_class = UniProcExecutor
-            elif distributed_executor_backend == "external_launcher":
-                # executor with external launcher
-                from vllm.executor.uniproc_executor import (  # noqa
-                    ExecutorWithExternalLauncher)
-                executor_class = ExecutorWithExternalLauncher
-        else:
+        elif distributed_executor_backend == "ray":
+            from vllm.executor.ray_distributed_executor import (
+                RayDistributedExecutor)
+            executor_class = RayDistributedExecutor
+        elif distributed_executor_backend == "mp":
+            from vllm.executor.mp_distributed_executor import (
+                MultiprocessingDistributedExecutor)
+            assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
+                "multiprocessing distributed executor backend does not "
+                "support VLLM_USE_RAY_SPMD_WORKER=1")
+            executor_class = MultiprocessingDistributedExecutor
+        elif distributed_executor_backend == "uni":
+            # JAX-style, single-process, multi-device executor.
+            from vllm.executor.uniproc_executor import UniProcExecutor
+            executor_class = UniProcExecutor
+        elif distributed_executor_backend == "external_launcher":
+            # executor with external launcher
+            from vllm.executor.uniproc_executor import (  # noqa
+                ExecutorWithExternalLauncher)
+            executor_class = ExecutorWithExternalLauncher
+        else:
             raise ValueError("unrecognized distributed_executor_backend: "
                              f"{distributed_executor_backend}")
         return executor_class

     @classmethod
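The isinstance branch, now followed by a flat elif chain instead of a world_size > 1 guard, is what lets callers pass an executor class directly, as the custom-executor tests above do. A small sketch of that path; the subclass name is illustrative:

from vllm import EngineArgs, LLMEngine
from vllm.executor.uniproc_executor import UniProcExecutor

class MyUniExecutor(UniProcExecutor):
    """Hypothetical subclass; inherits all behavior unchanged."""

engine = LLMEngine.from_engine_args(
    EngineArgs(
        model="facebook/opt-125m",
        distributed_executor_backend=MyUniExecutor,
        enforce_eager=True,
    ))

Since the string branches are no longer nested under world_size > 1, the same resolution rules apply to single-process and multi-process runs alike.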
17 changes: 8 additions & 9 deletions vllm/v1/executor/abstract.py

@@ -25,15 +25,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
         parallel_config = vllm_config.parallel_config
         distributed_executor_backend = (
             parallel_config.distributed_executor_backend)
-        if distributed_executor_backend is None:
-            # If the user does not specify the distributed executor backend,
-            # we will choose the backend based on the world size.
-            if parallel_config.world_size > 1:
-                distributed_executor_backend = "mp"
-            else:
-                distributed_executor_backend = "uni"
-
-        if distributed_executor_backend == "ray":
+        # distributed_executor_backend must be set in VllmConfig.__post_init__
+        if isinstance(distributed_executor_backend, type):
+            if not issubclass(distributed_executor_backend, ExecutorBase):
+                raise TypeError(
+                    "distributed_executor_backend must be a subclass of "
+                    f"ExecutorBase. Got {distributed_executor_backend}.")
+            executor_class = distributed_executor_backend
+        elif distributed_executor_backend == "ray":
             executor_class = RayDistributedExecutor
         elif distributed_executor_backend == "mp":
             from vllm.v1.executor.multiproc_executor import MultiprocExecutor
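With this change the V1 path mirrors the V0 engine: the backend is assumed to be resolved in VllmConfig.__post_init__, and a user-supplied executor class is validated the same way. A sketch of looking up the resolved class, under the same assumptions as above:

from vllm.engine.arg_utils import EngineArgs
from vllm.v1.executor.abstract import Executor

config = EngineArgs(model="facebook/opt-125m",
                    enforce_eager=True).create_engine_config()
executor_cls = Executor.get_class(config)
print(executor_cls.__name__)  # e.g. a uniprocess executor when world_size == 1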
