@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -492,9 +495,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.

     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -505,23 +505,22 @@
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
        self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)

         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)

         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
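A bound method such as `self.process_request_outputs` holds a strong reference to `self`, so installing it directly as the engine's callback would keep the `AsyncLLMEngine` alive indefinitely. The diff imports `weak_bind` from `vllm.utils` without showing its body; a minimal sketch of what such a helper can look like, assuming it simply no-ops once the instance is collected:

import weakref
from typing import Any, Callable

def weak_bind(bound_method: Callable[..., Any]) -> Callable[..., None]:
    # Keep the instance weakly and the unbound function strongly, so the
    # callback does not pin the owning object in memory.
    ref = weakref.ref(bound_method.__self__)
    unbound = bound_method.__func__

    def weak_bound(*args: Any, **kwargs: Any) -> None:
        if inst := ref():  # silently no-op after the owner is collected
            unbound(inst, *args, **kwargs)

    return weak_bound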
@@ -534,6 +533,11 @@ def __init__(self,
         # Lazy initialized fields
         self._request_tracker: RequestTracker

+    def __del__(self):
+        if rt := getattr(self, "_request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
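Note that `_request_tracker` is only annotated in `__init__` and not assigned until `start_background_loop` runs, which is why `__del__` goes through `getattr` with a `None` default instead of touching the attribute directly: `__del__` can fire on a partially initialized object. A small illustration of the guard (class and attribute names here are illustrative, not vLLM's):

import asyncio

class Owner:
    def start(self) -> None:
        # Assigned lazily, mirroring _request_tracker above.
        self.event = asyncio.Event()

    def __del__(self):
        # getattr with a default is safe even if start() never ran.
        if ev := getattr(self, "event", None):
            ev.set()  # wake any waiter so it can exit cleanly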
@@ -544,15 +548,12 @@ def _get_executor_cls(
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -573,19 +574,16 @@ def _get_executor_cls(
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
             else:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -601,19 +599,23 @@ def _get_executor_cls(
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()

         executor_class = cls._get_executor_cls(engine_config)

+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
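Centralizing `initialize_ray_cluster` here, keyed off `executor_class.uses_ray`, replaces the scattered per-backend calls deleted from `_get_executor_cls` above and makes the redundant `worker_use_ray` constructor argument unnecessary. The new optional `engine_config` parameter also lets callers that have already built an `EngineConfig` pass it in instead of paying for `create_engine_config()` twice; a usage sketch (the model name is only an example):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(model="facebook/opt-125m")

# Build the config once, e.g. to inspect parallelism settings first...
engine_config = engine_args.create_engine_config()

# ...then hand it over so from_engine_args does not recreate it.
engine = AsyncLLMEngine.from_engine_args(engine_args,
                                         engine_config=engine_config)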
@@ -670,7 +672,7 @@ def start_background_loop(self) -> None:
         self._request_tracker = RequestTracker()

         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -740,9 +742,16 @@ def process_request_outputs(self, request_outputs) -> bool:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)

-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
@@ -753,11 +762,21 @@ async def run_engine_loop(self):
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
             logger.debug("Got new requests!")
             requests_in_progress = [
-                asyncio.create_task(self.engine_step(ve))
+                asyncio.create_task(engine.engine_step(ve))
                 for ve in range(pipeline_parallel_size)
             ]
             has_requests_in_progress = [True] * pipeline_parallel_size
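The sequence above — capture a strong reference, use it, `del` it before blocking, then re-check the weakref after waking — is what lets this otherwise-infinite task die with its owner: while parked in `wait_for_new_requests()` the loop holds no strong reference to the engine, and `__del__` sets the event so the task wakes, sees a dead weakref, and returns. A self-contained toy version of the pattern (all names illustrative, not vLLM's):

import asyncio
import weakref

class Engine:
    def __init__(self) -> None:
        self.new_requests_event = asyncio.Event()
        # The task receives only a weakref, so it never pins the engine.
        self._loop_task = asyncio.create_task(
            Engine.run_loop(weakref.ref(self)))

    def __del__(self):
        # Wake the loop so it can observe the dead weakref and exit.
        self.new_requests_event.set()

    @staticmethod
    async def run_loop(engine_ref: weakref.ReferenceType) -> None:
        while True:
            engine = engine_ref()
            if engine is None:
                return  # owner collected: exit cleanly
            event = engine.new_requests_event
            del engine  # drop the strong ref before blocking
            await event.wait()
            event.clear()

async def main() -> None:
    engine = Engine()
    await asyncio.sleep(0)    # let the loop task start and park
    del engine                # CPython refcount drop runs __del__ here
    await asyncio.sleep(0.1)  # loop task wakes, sees dead ref, exits

asyncio.run(main())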
@@ -775,19 +794,20 @@ async def run_engine_loop(self):
                         result = task.result()
                         virtual_engine = requests_in_progress.index(task)
                         has_unfinished_requests = (
-                            self.engine.has_unfinished_requests_for_virtual_engine(
+                            engine.engine.
+                            has_unfinished_requests_for_virtual_engine(
                                 virtual_engine))
                         if result or has_unfinished_requests:
                             requests_in_progress[virtual_engine] = (
                                 asyncio.create_task(
-                                    self.engine_step(virtual_engine)))
+                                    engine.engine_step(virtual_engine)))
                             has_requests_in_progress[virtual_engine] = True
                         else:
                             has_requests_in_progress[virtual_engine] = False
            except asyncio.TimeoutError as exc:
                logger.error(
                    "Engine iteration timed out. This should never happen!")
-                engine.set_errored(exc)
+                engine.set_errored(exc)
                raise
            await asyncio.sleep(0)