From fb2215f3f2619d74e614e7b47e090f8bdceb697a Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 2 Jul 2024 14:14:26 -0700
Subject: [PATCH] [BugFix] Avoid unnecessary Ray import warnings

Currently the logs are polluted with multiple Ray import warnings even
when Ray is not being used. An error should be raised only when Ray is
configured or required but unavailable, and in that case it should be
fatal. In our own builds/tests, the warning was also interfering with
the subprocess-based custom all-reduce p2p check recently introduced in
https://github.com/vllm-project/vllm/pull/5669.
---
 vllm/config.py                  |  9 +++++++--
 vllm/engine/async_llm_engine.py |  5 +++++
 vllm/executor/ray_utils.py      | 23 ++++++++++++++++-------
 3 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 9a7e0ea7a3a10..9633b7dc63290 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -655,11 +655,13 @@ def __init__(
             from vllm.executor import ray_utils
             backend = "mp"
-            ray_found = ray_utils.ray is not None
+            ray_found = ray_utils.ray_is_available()
             if cuda_device_count_stateless() < self.world_size:
                 if not ray_found:
                     raise ValueError("Unable to load Ray which is "
-                                     "required for multi-node inference")
+                                     "required for multi-node inference, "
+                                     "please install Ray with `pip install "
+                                     "ray`.") from ray_utils.ray_import_err
                 backend = "ray"
             elif ray_found:
                 if self.placement_group:
@@ -691,6 +693,9 @@ def _verify_args(self) -> None:
             raise ValueError(
                 "Unrecognized distributed executor backend. Supported values "
                 "are 'ray' or 'mp'.")
+        if self.distributed_executor_backend == "ray":
+            from vllm.executor import ray_utils
+            ray_utils.assert_ray_available()
         if not self.disable_custom_all_reduce and self.world_size > 1:
             if is_hip():
                 self.disable_custom_all_reduce = True
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 0ce511ce42476..c972638cd11d3 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -378,6 +378,11 @@ def from_engine_args(
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
         engine_config = engine_args.create_engine_config()
+
+        if engine_args.engine_use_ray:
+            from vllm.executor import ray_utils
+            ray_utils.assert_ray_available()
+
         distributed_executor_backend = (
             engine_config.parallel_config.distributed_executor_backend)
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 495fddd175dd4..242d6c136655f 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -42,14 +42,26 @@ def execute_model_compiled_dag_remote(self, ignored):
             output = pickle.dumps(output)
             return output
 
+    ray_import_err = None
+
 except ImportError as e:
-    logger.warning(
-        "Failed to import Ray with %r. For multi-node inference, "
-        "please install Ray with `pip install ray`.", e)
     ray = None  # type: ignore
+    ray_import_err = e
     RayWorkerWrapper = None  # type: ignore
 
 
+def ray_is_available() -> bool:
+    """Returns True if Ray is available."""
+    return ray is not None
+
+
+def assert_ray_available():
+    """Raise an exception if Ray is not available."""
+    if ray is None:
+        raise ValueError("Failed to import Ray, please install Ray with "
+                         "`pip install ray`.") from ray_import_err
+
+
 def initialize_ray_cluster(
     parallel_config: ParallelConfig,
     ray_address: Optional[str] = None,
@@ -65,10 +77,7 @@ def initialize_ray_cluster(
         ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.
""" - if ray is None: - raise ImportError( - "Ray is not installed. Please install Ray to use multi-node " - "serving.") + assert_ray_available() # Connect to a ray cluster. if is_hip() or is_xpu():