Commit c0026e3

[BugFix] Fix clean shutdown issues (vllm-project#8492)
njhill authored and dtrifiro committed
1 parent: efc78c9

11 files changed: +215 −136 lines

tests/async_engine/test_async_llm_engine.py (+8, −2)

@@ -26,6 +26,11 @@ class RequestOutput:
     finished: bool = False
 
 
+@dataclass
+class MockModelConfig:
+    use_async_output_proc = True
+
+
 class MockEngine:
 
     def __init__(self):
@@ -35,6 +40,7 @@ def __init__(self):
         self.request_id = None
         # Ugly, remove dependency when possible
         self.parallel_config = ParallelConfig(1, 1, False)
+        self.model_config = MockModelConfig()
 
     async def step_async(self, virtual_engine):
         # PP size is 1, ignore virtual engine
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
 
 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False)
+    engine = MockAsyncLLMEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -113,7 +119,7 @@ async def test_new_requests_event():
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
 
-    engine = MockAsyncLLMEngine(worker_use_ray=True)
+    engine = MockAsyncLLMEngine()
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None

vllm/engine/async_llm_engine.py (+45, −25)

@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType
 
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
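The `weak_bind` helper imported above is added to `vllm/utils.py` elsewhere in this commit (that file is not shown in this excerpt). A minimal sketch of the idea, assuming the helper simply rebinds an instance method through a weak reference so that stored callbacks no longer keep the engine alive:

import weakref
from typing import Any, Callable


def weak_bind(bound_method: Callable[..., Any]) -> Callable[..., None]:
    # Hypothetical sketch, not necessarily the exact vLLM implementation.
    # Hold only a weakref to the method's instance; once the instance has
    # been garbage collected, the returned callable becomes a silent no-op.
    ref = weakref.ref(bound_method.__self__)
    unbound = bound_method.__func__

    def weak_bound(*args, **kwargs) -> None:
        instance = ref()
        if instance is not None:
            unbound(instance, *args, **kwargs)

    return weak_bound

Bound methods (and partials over them) hold a strong reference to their instance, which is exactly what kept the engine alive through its own callbacks before this change.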
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.
 
     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -463,23 +463,22 @@
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
 
     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)
 
         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)
 
         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
@@ -492,6 +491,11 @@ def __init__(self,
         # Lazy initialized fields
         self._request_tracker: RequestTracker
 
+    def __del__(self):
+        if rt := getattr(self, "request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
@@ -502,15 +506,12 @@ def _get_executor_cls(
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -531,19 +532,16 @@ def _get_executor_cls(
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
             else:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -559,19 +557,23 @@ def _get_executor_cls(
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()
 
         executor_class = cls._get_executor_cls(engine_config)
 
+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
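The new optional `engine_config` parameter lets a caller that has already built the config reuse it instead of paying for a second `create_engine_config()` call; existing callers are unaffected. A small usage sketch (the model name is only an example):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(model="facebook/opt-125m")

# Existing behavior: the engine config is created internally.
engine = AsyncLLMEngine.from_engine_args(engine_args)

# New in this commit: pass a config that was already built elsewhere.
engine_config = engine_args.create_engine_config()
engine = AsyncLLMEngine.from_engine_args(engine_args,
                                         engine_config=engine_config)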
@@ -628,7 +630,7 @@ def start_background_loop(self) -> None:
         self._request_tracker = RequestTracker()
 
         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ def process_request_outputs(self, request_outputs) -> bool:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)
 
-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
@@ -711,11 +720,21 @@ async def run_engine_loop(self):
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
                 logger.debug("Got new requests!")
                 requests_in_progress = [
-                    asyncio.create_task(self.engine_step(ve))
+                    asyncio.create_task(engine.engine_step(ve))
                     for ve in range(pipeline_parallel_size)
                 ]
                 has_requests_in_progress = [True] * pipeline_parallel_size
@@ -733,19 +752,20 @@ async def run_engine_loop(self):
                     result = task.result()
                     virtual_engine = requests_in_progress.index(task)
                     has_unfinished_requests = (
-                        self.engine.has_unfinished_requests_for_virtual_engine(
+                        engine.engine.
+                        has_unfinished_requests_for_virtual_engine(
                             virtual_engine))
                     if result or has_unfinished_requests:
                         requests_in_progress[virtual_engine] = (
                             asyncio.create_task(
-                                self.engine_step(virtual_engine)))
+                                engine.engine_step(virtual_engine)))
                         has_requests_in_progress[virtual_engine] = True
                     else:
                         has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
+                engine.set_errored(exc)
                 raise
             await asyncio.sleep(0)
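Taken together, the weakly bound callbacks, the weakref handed to the now-static `run_engine_loop`, and the `__del__` hook that sets `new_requests_event` mean the background task no longer pins the engine: once the last external reference is dropped, the engine can be collected and the loop task observes that and exits instead of keeping the process alive. A self-contained sketch of the pattern with hypothetical names (not vLLM code; relies on CPython's immediate refcount-based collection):

import asyncio
import weakref
from typing import Optional


class Owner:
    # Stand-in for the engine: owns the event the background loop waits on.

    def __init__(self) -> None:
        self.new_requests_event = asyncio.Event()

    def __del__(self) -> None:
        # Wake the loop so it can notice that its owner is gone.
        self.new_requests_event.set()

    @staticmethod
    async def run_loop(owner_ref: "weakref.ReferenceType[Owner]") -> None:
        while True:
            owner: Optional[Owner] = owner_ref()
            if owner is None:
                return                   # owner collected -> exit cleanly
            event = owner.new_requests_event
            del owner                    # drop the strong ref while waiting
            await event.wait()
            event.clear()


async def main() -> None:
    owner = Owner()
    task = asyncio.create_task(Owner.run_loop(weakref.ref(owner)))
    await asyncio.sleep(0)                    # let the loop reach event.wait()
    del owner                                 # last strong reference goes away
    await asyncio.wait_for(task, timeout=1)   # the loop exits instead of hanging


asyncio.run(main())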

vllm/engine/llm_engine.py (+13, −8)

@@ -1,8 +1,8 @@
-import functools
 import time
 from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
@@ -51,7 +51,7 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device
+from vllm.utils import Counter, Device, weak_bind
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -382,11 +382,16 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
 
-        self.async_callbacks = [
-            functools.partial(self._process_model_outputs,
-                              ctx=self.scheduler_contexts[v_id])
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
+        if model_config.use_async_output_proc:
+            process_model_outputs = weak_bind(self._process_model_outputs)
+
+            self.async_callbacks = [
+                partial(process_model_outputs,
+                        ctx=self.scheduler_contexts[v_id])
+                for v_id in range(self.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.async_callbacks = []
 
         # Currently used by AsyncLLMEngine to ensure quick append
         # of request outputs to asyncio queues
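The switch from `functools.partial(self._process_model_outputs, ...)` to `partial(weak_bind(self._process_model_outputs), ...)` matters because a partial over a bound method keeps the engine alive for as long as the callback list exists, while the weakly bound version does not. A small illustration, assuming `weak_bind` behaves like the sketch shown earlier (the `Worker` class is hypothetical):

import gc
import weakref
from functools import partial

from vllm.utils import weak_bind


class Worker:
    def process(self, ctx: int) -> None:
        print("processing context", ctx)


worker = Worker()
ref = weakref.ref(worker)

strong_cb = partial(worker.process, ctx=0)           # bound method pins the instance
weak_cb = partial(weak_bind(worker.process), ctx=0)  # holds only a weak reference

del worker
gc.collect()
assert ref() is not None   # still alive: strong_cb keeps the Worker around

del strong_cb
gc.collect()
assert ref() is None       # with only weak_cb left, the Worker is collected
weak_cb()                  # assumed to be a safe no-op once the instance is gone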
@@ -869,8 +874,8 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
+    @staticmethod
     def _process_sequence_group_outputs(
-        self,
         seq_group: SequenceGroup,
         outputs: List[EmbeddingSequenceGroupOutput],
     ) -> None:

vllm/entrypoints/launcher.py (+10, −11)

@@ -1,21 +1,20 @@
 import asyncio
 import signal
 from http import HTTPStatus
-from typing import Any
+from typing import Any, Optional
 
 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Request, Response
 
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
-from vllm.engine.protocol import AsyncEngineClient
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
 
 logger = init_logger(__name__)
 
 
-async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
                      **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
 
     # Set concurrency limits in uvicorn if running in multiprocessing mode
     # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
-    if engine.limit_concurrency is not None:
+    if limit_concurrency is not None:
         logger.info(
             "Launching Uvicorn with --limit_concurrency %s. To avoid this "
             "limit at the expense of performance run with "
-            "--disable-frontend-multiprocessing", engine.limit_concurrency)
-        uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
+            "--disable-frontend-multiprocessing", limit_concurrency)
+        uvicorn_kwargs["limit_concurrency"] = limit_concurrency
 
     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
-    _add_shutdown_handlers(app, server, engine)
+    _add_shutdown_handlers(app, server)
 
     loop = asyncio.get_running_loop()
 
@@ -68,15 +67,15 @@ async def dummy_shutdown() -> None:
         return server.shutdown()
 
 
-def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
-                           engine: AsyncEngineClient) -> None:
+def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""
 
     @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(_, __):
+    async def runtime_error_handler(request: Request, __):
         """On generic runtime error, check to see if the engine has died.
         It probably has, in which case the server will no longer be able to
         handle requests. Trigger a graceful shutdown with a SIGTERM."""
+        engine = request.app.state.engine_client
         if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                 and not engine.is_running):
             logger.fatal("AsyncLLMEngine has failed, terminating server "
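The handler now looks the engine client up on `request.app.state` instead of capturing it in a closure, so the launcher no longer needs to hold an engine reference of its own. This assumes the API server stores the client there during startup, which the rest of this PR presumably does in files not shown here; a minimal sketch of that wiring (hypothetical names and values):

import asyncio

from fastapi import FastAPI

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http


async def run() -> None:
    app = FastAPI()
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    # Make the engine client reachable from request handlers via app.state,
    # so shutdown handlers no longer need it passed in explicitly.
    app.state.engine_client = engine

    # serve_http now takes the concurrency limit directly instead of the engine.
    shutdown_coro = await serve_http(app, limit_concurrency=None,
                                     host="127.0.0.1", port=8000)
    await shutdown_coro


if __name__ == "__main__":
    asyncio.run(run())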
