From f534c4f3de3ad4fc326a21d041130a488cc55922 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Feb 2025 11:52:21 +0800 Subject: [PATCH 1/2] respect distributed_executor_backend Signed-off-by: youkaichao --- vllm/config.py | 3 +++ vllm/engine/llm_engine.py | 44 ++++++++++++++++++------------------ vllm/v1/executor/abstract.py | 17 +++++++------- 3 files changed, 33 insertions(+), 31 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5579d6936d105..426ba38080270 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1401,6 +1401,9 @@ def __post_init__(self) -> None: logger.info("Defaulting to use %s for distributed inference", backend) + if self.distributed_executor_backend is None and self.world_size == 1: + self.distributed_executor_backend = "uni" + self._verify_args() @property diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9df323..2e5bc75c6db38 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -434,6 +434,7 @@ def _initialize_kv_caches(self) -> None: @classmethod def _get_executor_cls(cls, engine_config: VllmConfig) -> Type[ExecutorBase]: + # distributed_executor_backend must be set in VllmConfig.__post_init__ distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. @@ -443,30 +444,29 @@ def _get_executor_cls(cls, "distributed_executor_backend must be a subclass of " f"ExecutorBase. 
Got {distributed_executor_backend}.") executor_class = distributed_executor_backend - elif engine_config.parallel_config.world_size > 1: - if distributed_executor_backend == "ray": - from vllm.executor.ray_distributed_executor import ( - RayDistributedExecutor) - executor_class = RayDistributedExecutor - elif distributed_executor_backend == "mp": - from vllm.executor.mp_distributed_executor import ( - MultiprocessingDistributedExecutor) - assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( - "multiprocessing distributed executor backend does not " - "support VLLM_USE_RAY_SPMD_WORKER=1") - executor_class = MultiprocessingDistributedExecutor - elif distributed_executor_backend == "uni": - # JAX-style, single-process, multi-device executor. - from vllm.executor.uniproc_executor import UniProcExecutor - executor_class = UniProcExecutor - elif distributed_executor_backend == "external_launcher": - # executor with external launcher - from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher) - executor_class = ExecutorWithExternalLauncher - else: + elif distributed_executor_backend == "ray": + from vllm.executor.ray_distributed_executor import ( + RayDistributedExecutor) + executor_class = RayDistributedExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.mp_distributed_executor import ( + MultiprocessingDistributedExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingDistributedExecutor + elif distributed_executor_backend == "uni": + # JAX-style, single-process, multi-device executor. 
from vllm.executor.uniproc_executor import UniProcExecutor executor_class = UniProcExecutor + elif distributed_executor_backend == "external_launcher": + # executor with external launcher + from vllm.executor.uniproc_executor import ( # noqa + ExecutorWithExternalLauncher) + executor_class = ExecutorWithExternalLauncher + else: + raise ValueError("unrecognized distributed_executor_backend: " + f"{distributed_executor_backend}") return executor_class @classmethod diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index ac10d43eb0d54..093be09ae11bb 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -25,15 +25,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]: parallel_config = vllm_config.parallel_config distributed_executor_backend = ( parallel_config.distributed_executor_backend) - if distributed_executor_backend is None: - # If the user does not specify the distributed executor backend, - # we will choose the backend based on the world size. - if parallel_config.world_size > 1: - distributed_executor_backend = "mp" - else: - distributed_executor_backend = "uni" - - if distributed_executor_backend == "ray": + # distributed_executor_backend must be set in VllmConfig.__post_init__ + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. 
Got {distributed_executor_backend}.") + executor_class = distributed_executor_backend + elif distributed_executor_backend == "ray": executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor From 72b64bb0d70ebcc16444c57e5badf247507161d5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 8 Feb 2025 11:58:00 +0800 Subject: [PATCH 2/2] add tests Signed-off-by: youkaichao --- ...st_custom_executor.py => test_executor.py} | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) rename tests/engine/{test_custom_executor.py => test_executor.py} (79%) diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_executor.py similarity index 79% rename from tests/engine/test_custom_executor.py rename to tests/engine/test_executor.py index 3e77faecbd3f5..84cc3ed63bb93 100644 --- a/tests/engine/test_custom_executor.py +++ b/tests/engine/test_executor.py @@ -55,6 +55,7 @@ def test_custom_executor(model, tmp_path): engine_args = EngineArgs( model=model, distributed_executor_backend=CustomUniExecutor, + enforce_eager=True, # reduce test time ) engine = LLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -75,7 +76,10 @@ def test_custom_executor_async(model, tmp_path): assert not os.path.exists(".marker") engine_args = AsyncEngineArgs( - model=model, distributed_executor_backend=CustomUniExecutorAsync) + model=model, + distributed_executor_backend=CustomUniExecutorAsync, + enforce_eager=True, # reduce test time + ) engine = AsyncLLMEngine.from_engine_args(engine_args) sampling_params = SamplingParams(max_tokens=1) @@ -89,3 +93,18 @@ async def t(): assert os.path.exists(".marker") finally: os.chdir(cwd) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_respect_ray(model): + # even for TP=1 and PP=1, + # if users specify ray, we should use ray. 
+ # users might do this if they want to manage the + # resources using ray. + engine_args = EngineArgs( + model=model, + distributed_executor_backend="ray", + enforce_eager=True, # reduce test time + ) + engine = LLMEngine.from_engine_args(engine_args) + assert engine.model_executor.uses_ray