
Commit cd16a23

njhill authored and Alvant committed
[BugFix] Fix clean shutdown issues (vllm-project#8492)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent ac82e83

11 files changed: +215 -136 lines

tests/async_engine/test_async_llm_engine.py (+8 -2)

@@ -26,6 +26,11 @@ class RequestOutput:
     finished: bool = False


+@dataclass
+class MockModelConfig:
+    use_async_output_proc = True
+
+
 class MockEngine:

     def __init__(self):
@@ -35,6 +40,7 @@ def __init__(self):
         self.request_id = None
         # Ugly, remove dependency when possible
         self.parallel_config = ParallelConfig(1, 1, False)
+        self.model_config = MockModelConfig()

     async def step_async(self, virtual_engine):
         # PP size is 1, ignore virtual engine
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False)
+    engine = MockAsyncLLMEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -113,7 +119,7 @@ async def test_new_requests_event():
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1

-    engine = MockAsyncLLMEngine(worker_use_ray=True)
+    engine = MockAsyncLLMEngine()
     assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
     assert engine.get_decoding_config() is not None
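
The mock changes mirror two API changes made below: AsyncLLMEngine no longer takes a worker_use_ray argument, and at construction time it now reads use_async_output_proc from the wrapped engine's model config, so MockEngine must expose a model_config. A tiny self-contained illustration of the lookup the mock has to satisfy (simplified, not the engine code itself):

from dataclasses import dataclass


@dataclass
class MockModelConfig:
    use_async_output_proc = True


class MockEngine:
    def __init__(self) -> None:
        self.model_config = MockModelConfig()


# Mirrors the attribute access AsyncLLMEngine now performs during __init__.
assert MockEngine().model_config.use_async_output_proc is True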

vllm/engine/async_llm_engine.py (+45 -25)

@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -492,9 +495,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.

     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -505,23 +505,22 @@ class AsyncLLMEngine:
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)

         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)

         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
@@ -534,6 +533,11 @@ def __init__(self,
         # Lazy initialized fields
         self._request_tracker: RequestTracker

+    def __del__(self):
+        if rt := getattr(self, "request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
@@ -544,15 +548,12 @@ def _get_executor_cls(
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -573,19 +574,16 @@ def _get_executor_cls(
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
             else:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -601,19 +599,23 @@ def _get_executor_cls(
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()

         executor_class = cls._get_executor_cls(engine_config)

+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
@@ -670,7 +672,7 @@ def start_background_loop(self) -> None:
         self._request_tracker = RequestTracker()

         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -740,9 +742,16 @@ def process_request_outputs(self, request_outputs) -> bool:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)

-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
@@ -753,11 +762,21 @@ async def run_engine_loop(self):
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
             logger.debug("Got new requests!")
             requests_in_progress = [
-                asyncio.create_task(self.engine_step(ve))
+                asyncio.create_task(engine.engine_step(ve))
                 for ve in range(pipeline_parallel_size)
             ]
             has_requests_in_progress = [True] * pipeline_parallel_size
@@ -775,19 +794,20 @@ async def run_engine_loop(self):
                         result = task.result()
                         virtual_engine = requests_in_progress.index(task)
                         has_unfinished_requests = (
-                            self.engine.has_unfinished_requests_for_virtual_engine(
+                            engine.engine.
+                            has_unfinished_requests_for_virtual_engine(
                                 virtual_engine))
                         if result or has_unfinished_requests:
                             requests_in_progress[virtual_engine] = (
                                 asyncio.create_task(
-                                    self.engine_step(virtual_engine)))
+                                    engine.engine_step(virtual_engine)))
                             has_requests_in_progress[virtual_engine] = True
                         else:
                             has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
+                engine.set_errored(exc)
                 raise
             await asyncio.sleep(0)

vllm/engine/llm_engine.py (+13 -8)

@@ -1,8 +1,8 @@
-import functools
 import time
 from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
@@ -51,7 +51,7 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device
+from vllm.utils import Counter, Device, weak_bind
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -382,11 +382,16 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        self.async_callbacks = [
-            functools.partial(self._process_model_outputs,
-                              ctx=self.scheduler_contexts[v_id])
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
+        if model_config.use_async_output_proc:
+            process_model_outputs = weak_bind(self._process_model_outputs)
+
+            self.async_callbacks = [
+                partial(process_model_outputs,
+                        ctx=self.scheduler_contexts[v_id])
+                for v_id in range(self.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.async_callbacks = []

         # Currently used by AsyncLLMEngine to ensure quick append
         # of request outputs to asyncio queues
@@ -916,8 +921,8 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()

+    @staticmethod
     def _process_sequence_group_outputs(
-            self,
             seq_group: SequenceGroup,
             outputs: List[EmbeddingSequenceGroupOutput],
     ) -> None:
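
The per-virtual-engine callbacks built here used to capture self via functools.partial, which pinned the LLMEngine in memory; they are now built from weak_bind(self._process_model_outputs), imported from vllm.utils (one of the changed files not shown on this page). A plausible sketch of what such a helper can look like, for illustration only (the real vllm.utils.weak_bind may differ):

import weakref
from typing import Callable


def weak_bind(bound_method: Callable) -> Callable:
    # Hypothetical sketch: hold the method's instance only weakly, so storing
    # the returned callback on long-lived objects does not keep the engine
    # alive. Once the instance is collected, the callback becomes a no-op.
    weak_method = weakref.WeakMethod(bound_method)

    def weak_bound(*args, **kwargs) -> None:
        method = weak_method()
        if method is not None:  # instance already collected -> do nothing
            method(*args, **kwargs)

    return weak_bound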

vllm/entrypoints/launcher.py (+10 -11)

@@ -1,21 +1,20 @@
 import asyncio
 import signal
 from http import HTTPStatus
-from typing import Any
+from typing import Any, Optional

 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Request, Response

 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
-from vllm.engine.protocol import AsyncEngineClient
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port

 logger = init_logger(__name__)


-async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
                      **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,

     # Set concurrency limits in uvicorn if running in multiprocessing mode
     # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
-    if engine.limit_concurrency is not None:
+    if limit_concurrency is not None:
         logger.info(
             "Launching Uvicorn with --limit_concurrency %s. To avoid this "
             "limit at the expense of performance run with "
-            "--disable-frontend-multiprocessing", engine.limit_concurrency)
-        uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
+            "--disable-frontend-multiprocessing", limit_concurrency)
+        uvicorn_kwargs["limit_concurrency"] = limit_concurrency

     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
-    _add_shutdown_handlers(app, server, engine)
+    _add_shutdown_handlers(app, server)

     loop = asyncio.get_running_loop()

@@ -68,15 +67,15 @@ async def dummy_shutdown() -> None:
         return server.shutdown()


-def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
-                           engine: AsyncEngineClient) -> None:
+def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server"""

    @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(_, __):
+    async def runtime_error_handler(request: Request, __):
        """On generic runtime error, check to see if the engine has died.
        It probably has, in which case the server will no longer be able to
        handle requests. Trigger a graceful shutdown with a SIGTERM."""
+        engine = request.app.state.engine_client
        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                and not engine.is_running):
            logger.fatal("AsyncLLMEngine has failed, terminating server "
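
serve_http no longer takes the engine client at all; the RuntimeError handler instead looks it up per request from app.state, so the launcher holds no strong reference to the engine. A caller-side sketch of the wiring this implies (simplified and hypothetical; the actual entrypoint changes, presumably in vllm/entrypoints/openai/api_server.py, are among the changed files not shown on this page):

from fastapi import FastAPI

from vllm.entrypoints.launcher import serve_http


async def run_server(engine_client, port: int = 8000) -> None:
    app = FastAPI()
    # Stash the client on app.state so handlers such as runtime_error_handler
    # can reach it via request.app.state.engine_client, without serve_http
    # keeping a strong reference to the engine itself.
    app.state.engine_client = engine_client
    await serve_http(
        app,
        limit_concurrency=engine_client.limit_concurrency,
        host="0.0.0.0",
        port=port,
    )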
