[bugfix] respect distributed_executor_backend in world_size=1 #12934

Merged · 2 commits · Feb 8, 2025
@@ -55,6 +55,7 @@ def test_custom_executor(model, tmp_path):
engine_args = EngineArgs(
model=model,
distributed_executor_backend=CustomUniExecutor,
enforce_eager=True, # reduce test time
)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
@@ -75,7 +76,10 @@ def test_custom_executor_async(model, tmp_path):
assert not os.path.exists(".marker")

engine_args = AsyncEngineArgs(
model=model, distributed_executor_backend=CustomUniExecutorAsync)
model=model,
distributed_executor_backend=CustomUniExecutorAsync,
enforce_eager=True, # reduce test time
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)

@@ -89,3 +93,18 @@ async def t():
assert os.path.exists(".marker")
finally:
os.chdir(cwd)


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
def test_respect_ray(model):
# even for TP=1 and PP=1,
# if users specify ray, we should use ray.
# users might do this if they want to manage the
# resources using ray.
engine_args = EngineArgs(
model=model,
distributed_executor_backend="ray",
enforce_eager=True, # reduce test time
)
engine = LLMEngine.from_engine_args(engine_args)
assert engine.model_executor.uses_ray
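
The new test pins down the user-facing behavior this PR fixes: an explicitly requested backend is honored even when tensor and pipeline parallelism are both 1. A rough sketch of the same behavior through the public LLM entry point; the model name and generation settings are illustrative, not taken from the PR:

from vllm import LLM, SamplingParams

# With this fix, "ray" is respected even though world_size == 1; previously a
# single-GPU run silently fell back to the single-process executor.
llm = LLM(
    model="facebook/opt-125m",
    distributed_executor_backend="ray",
    enforce_eager=True,  # shorter startup, mirroring the tests above
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)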
vllm/config.py (3 additions, 0 deletions)

@@ -1401,6 +1401,9 @@ def __post_init__(self) -> None:
logger.info("Defaulting to use %s for distributed inference",
backend)

if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"

self._verify_args()

@property
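This config change is the core of the fix: __post_init__ now fills in a concrete backend for the single-process case instead of leaving it None, so downstream executor selection never has to guess from the world size. A simplified paraphrase of the resulting resolution order as a hypothetical standalone helper (the world_size > 1 behavior shown is an assumption based on the "Defaulting to use %s" log just above, not part of this diff):

def resolve_backend(user_backend, world_size):
    # An explicit user choice always wins, even for a single GPU.
    if user_backend is not None:
        return user_backend
    # Unset + multiple workers: pick a distributed default ("ray" when a Ray
    # cluster is already configured, otherwise "mp" -- assumed, see note above).
    if world_size > 1:
        return "mp"
    # Unset + one worker: the new default added by this PR.
    return "uni"

assert resolve_backend("ray", 1) == "ray"
assert resolve_backend(None, 1) == "uni"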
vllm/engine/llm_engine.py (22 additions, 22 deletions)

@@ -434,6 +434,7 @@ def _initialize_kv_caches(self) -> None:
@classmethod
def _get_executor_cls(cls,
engine_config: VllmConfig) -> Type[ExecutorBase]:
# distributed_executor_backend must be set in VllmConfig.__post_init__
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
@@ -443,30 +444,29 @@ def _get_executor_cls(cls,
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif engine_config.parallel_config.world_size > 1:
if distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.mp_distributed_executor import (
MultiprocessingDistributedExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingDistributedExecutor
elif distributed_executor_backend == "uni":
# JAX-style, single-process, multi-device executor.
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# executor with external launcher
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher)
executor_class = ExecutorWithExternalLauncher
else:
elif distributed_executor_backend == "ray":
from vllm.executor.ray_distributed_executor import (
RayDistributedExecutor)
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.executor.mp_distributed_executor import (
MultiprocessingDistributedExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingDistributedExecutor
elif distributed_executor_backend == "uni":
# JAX-style, single-process, multi-device executor.
from vllm.executor.uniproc_executor import UniProcExecutor
executor_class = UniProcExecutor
elif distributed_executor_backend == "external_launcher":
# executor with external launcher
from vllm.executor.uniproc_executor import ( # noqa
ExecutorWithExternalLauncher)
executor_class = ExecutorWithExternalLauncher
else:
raise ValueError("unrecognized distributed_executor_backend: "
f"{distributed_executor_backend}")
return executor_class

@classmethod
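After this refactor, _get_executor_cls no longer branches on world size at all: a string backend maps straight to its executor class, and the first branch still accepts an ExecutorBase subclass directly. A sketch of that class-based path, modeled on the CustomUniExecutor used in the test file above; the override and the ".marker" side effect are assumptions patterned on that test, not an excerpt from it:

from vllm import LLM
from vllm.executor.uniproc_executor import UniProcExecutor

class MarkerUniExecutor(UniProcExecutor):
    # Tag every collective RPC so callers can verify this executor was used.
    def collective_rpc(self, method, timeout=None, args=(), kwargs=None):
        with open(".marker", "w"):
            pass
        return super().collective_rpc(method, timeout, args, kwargs)

llm = LLM(model="facebook/opt-125m",
          distributed_executor_backend=MarkerUniExecutor,  # a class, not a string
          enforce_eager=True)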
vllm/v1/executor/abstract.py (8 additions, 9 deletions)

@@ -25,15 +25,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
parallel_config = vllm_config.parallel_config
distributed_executor_backend = (
parallel_config.distributed_executor_backend)
if distributed_executor_backend is None:
# If the user does not specify the distributed executor backend,
# we will choose the backend based on the world size.
if parallel_config.world_size > 1:
distributed_executor_backend = "mp"
else:
distributed_executor_backend = "uni"

if distributed_executor_backend == "ray":
# distributed_executor_backend must be set in VllmConfig.__post_init__
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
executor_class = distributed_executor_backend
elif distributed_executor_backend == "ray":
executor_class = RayDistributedExecutor
elif distributed_executor_backend == "mp":
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
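
Both the v0 and v1 selectors now lean on the same invariant spelled out in the added comments: by the time an executor class is chosen, VllmConfig.__post_init__ has already pinned distributed_executor_backend down. A quick sketch of checking that invariant from engine args (assumes EngineArgs.create_engine_config() builds the VllmConfig, as it did in vLLM around the time of this PR):

from vllm.engine.arg_utils import EngineArgs

# Unset backend + world_size == 1 now resolves to "uni" in __post_init__.
config = EngineArgs(model="facebook/opt-125m").create_engine_config()
assert config.parallel_config.distributed_executor_backend == "uni"

# An explicit choice survives untouched, which is what both executors rely on.
config = EngineArgs(model="facebook/opt-125m",
                    distributed_executor_backend="ray").create_engine_config()
assert config.parallel_config.distributed_executor_backend == "ray"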