@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.

     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)

         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)

         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
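Note on the hunk above: storing the plain bound method `self.process_request_outputs` on `self.engine` would give the inner engine a strong reference back to the `AsyncLLMEngine`, defeating the weakref-based cleanup this PR introduces. A minimal sketch of what such a weak-bind helper can look like (the actual `vllm.utils.weak_bind` may be implemented differently):

import weakref
from typing import Callable

def weak_bind_sketch(bound_method: Callable) -> Callable:
    # Hold the instance weakly; re-resolve it on every call.
    ref = weakref.ref(bound_method.__self__)
    func = bound_method.__func__

    def weak_bound(*args, **kwargs):
        inst = ref()
        # Silently become a no-op once the instance has been collected.
        if inst is not None:
            return func(inst, *args, **kwargs)

    return weak_bound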
@@ -492,6 +491,11 @@ def __init__(self,
         # Lazy initialized fields
         self._request_tracker: RequestTracker

+    def __del__(self):
+        if rt := getattr(self, "_request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
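The new `__del__` pairs with the weakref-driven loop further down: when the last strong reference to the engine drops, the loop may still be parked in `wait_for_new_requests()`, so the finalizer sets the tracker's event to wake it. A self-contained sketch of this handshake, with hypothetical `Tracker`/`Engine` names:

import asyncio
import weakref

class Tracker:
    def __init__(self) -> None:
        self.new_requests_event = asyncio.Event()

class Engine:
    def __init__(self) -> None:
        self._request_tracker = Tracker()

    def __del__(self):
        # Wake the loop so it can observe that the engine is gone.
        self._request_tracker.new_requests_event.set()

async def engine_loop(engine_ref: weakref.ReferenceType,
                      tracker: Tracker) -> None:
    while True:
        await tracker.new_requests_event.wait()
        tracker.new_requests_event.clear()
        if engine_ref() is None:
            return  # Engine was garbage collected; exit cleanly.
        # ... process the newly arrived requests ...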
@@ -502,15 +506,12 @@ def _get_executor_cls(
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -531,19 +532,16 @@ def _get_executor_cls(
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
             else:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
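Taken together with the `from_engine_args` hunk below, these removals turn `_get_executor_cls` into a side-effect-free lookup; Ray is initialized exactly once at the call site, keyed off `executor_class.uses_ray`. A simplified sketch of that shape (hypothetical names):

class RayExecutorAsync:
    uses_ray = True

class LocalExecutorAsync:
    uses_ray = False

def get_executor_cls(backend: str):
    # Pure function: no initialize_ray_cluster() calls in the branches.
    return RayExecutorAsync if backend == "ray" else LocalExecutorAsync

# Call site (mirrors from_engine_args below):
# executor_class = get_executor_cls(backend)
# if executor_class.uses_ray:
#     initialize_ray_cluster(engine_config.parallel_config)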
@@ -559,19 +557,23 @@ def _get_executor_cls(
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()

         executor_class = cls._get_executor_cls(engine_config)

+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
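With the new optional `engine_config` parameter, a caller that has already built (or wants to inspect) the config can pass it in instead of having `from_engine_args` recreate it. A usage sketch (the model name is only an example):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(model="facebook/opt-125m")
engine_config = engine_args.create_engine_config()

# The config can be examined or reused before the engine is built;
# from_engine_args will not rebuild it when one is supplied.
engine = AsyncLLMEngine.from_engine_args(engine_args,
                                         engine_config=engine_config)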
@@ -628,7 +630,7 @@ def start_background_loop(self) -> None:
         self._request_tracker = RequestTracker()

         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ def process_request_outputs(self, request_outputs) -> bool:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)

-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
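This is the core pattern of the PR: the engine loop becomes a `@staticmethod` that receives `weakref.ref(self)` rather than `self`, so the long-running task never keeps its owner alive and ordinary garbage collection is enough to end the loop. A standalone sketch with a hypothetical `Worker` class:

import asyncio
import weakref

class Worker:
    def start(self) -> None:
        # Call from within a running event loop; the task holds only
        # a weak reference back to this instance.
        self._task = asyncio.get_event_loop().create_task(
            Worker._run(weakref.ref(self)))

    @staticmethod
    async def _run(worker_ref: weakref.ReferenceType) -> None:
        while True:
            worker = worker_ref()
            if worker is None:
                return  # Owner was collected; exit the loop cleanly.
            # ... perform one iteration of work with `worker` ...
            del worker  # Drop the strong ref before suspending.
            await asyncio.sleep(0)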
@@ -711,11 +720,21 @@ async def run_engine_loop(self):
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
                 logger.debug("Got new requests!")
                 requests_in_progress = [
-                    asyncio.create_task(self.engine_step(ve))
+                    asyncio.create_task(engine.engine_step(ve))
                     for ve in range(pipeline_parallel_size)
                 ]
                 has_requests_in_progress = [True] * pipeline_parallel_size
@@ -733,19 +752,20 @@ async def run_engine_loop(self):
                     result = task.result()
                     virtual_engine = requests_in_progress.index(task)
                     has_unfinished_requests = (
-                        self.engine.has_unfinished_requests_for_virtual_engine(
+                        engine.engine.
+                        has_unfinished_requests_for_virtual_engine(
                             virtual_engine))
                     if result or has_unfinished_requests:
                         requests_in_progress[virtual_engine] = (
                             asyncio.create_task(
-                                self.engine_step(virtual_engine)))
+                                engine.engine_step(virtual_engine)))
                         has_requests_in_progress[virtual_engine] = True
                     else:
                         has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
+                engine.set_errored(exc)
                 raise
             await asyncio.sleep(0)