@@ -202,7 +202,7 @@ def schedule(self) -> "SchedulerOutput":
                 # which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
                 if num_new_tokens == 0:
-                    # The happens when prompt length is divisible by the block
+                    # This happens when prompt length is divisible by the block
                     # size and all blocks are cached. Now we force to recompute
                     # the last block. Note that we have to re-compute an entire
                     # block because allocate_slots() assumes num_computed_tokens
@@ -269,6 +269,7 @@ def schedule(self) -> "SchedulerOutput":

         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
+        num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
             num_common_prefix_blocks = (
@@ -433,7 +434,8 @@ def update_from_output(
                 if start_pos + num_tokens <= request.num_computed_tokens:
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
-                    self.encoder_cache_manager.free(request, input_id)
+                    self.encoder_cache_manager.free_encoder_input(
+                        request, input_id)

             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
@@ -445,8 +447,10 @@ def update_from_output(
                 # TODO: Update the KV cache manager for prefix caching.

                 # Check for stop and update request state.
-                # This must be called before me make the EngineCoreOutput.
+                # This must be called before we make the EngineCoreOutput.
                 stopped = self._check_stop(request)
+                if stopped:
+                    self._free_request(request)

                 # Add EngineCoreOutput for this Request.
                 output = EngineCoreOutput(
@@ -472,21 +476,18 @@ def _check_stop(self, request: Request) -> bool:
         if (request.num_tokens >= self.max_model_len
                 or request.num_output_tokens >= request.max_tokens):
             request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            self._free_request(request)
             return True

         sampling_params = request.sampling_params
         last_token_id = request.output_token_ids[-1]
         if (not sampling_params.ignore_eos
                 and last_token_id == request.eos_token_id):
             request.status = RequestStatus.FINISHED_STOPPED
-            self._free_request(request)
             return True

         if last_token_id in (sampling_params.stop_token_ids or ()):
             request.status = RequestStatus.FINISHED_STOPPED
             request.stop_reason = last_token_id
-            self._free_request(request)
             return True
         return False
@@ -525,6 +526,7 @@ def finish_requests(
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
         self.running_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)
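
Taken together, these hunks move request cleanup out of _check_stop() and into update_from_output(): the stop check now only updates request state, and the caller frees the request (KV cache blocks plus any cached encoder outputs) after the stop check, so the final EngineCoreOutput can still be built for the request. The following is a minimal sketch of that control flow, not the real vLLM code; FakeRequest, SketchScheduler, and the list-based "cache managers" are simplified stand-ins introduced only for illustration.

from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class FakeRequest:
    # Simplified stand-in for vLLM's Request.
    request_id: str
    num_output_tokens: int = 0
    max_tokens: int = 4
    finished: bool = False

    def is_finished(self) -> bool:
        return self.finished


@dataclass
class SketchScheduler:
    # Lists/sets stand in for the KV cache manager and encoder cache manager.
    kv_freed: List[str] = field(default_factory=list)
    encoder_freed: List[str] = field(default_factory=list)
    finished_req_ids: Set[str] = field(default_factory=set)

    def _check_stop(self, request: FakeRequest) -> bool:
        # After this change: only mark the request as finished here;
        # freeing is deferred to the caller so the final output can be built.
        if request.num_output_tokens >= request.max_tokens:
            request.finished = True
            return True
        return False

    def _free_request(self, request: FakeRequest) -> None:
        assert request.is_finished()
        self.kv_freed.append(request.request_id)       # ~ kv_cache_manager.free(request)
        self.encoder_freed.append(request.request_id)  # ~ encoder_cache_manager.free(request)
        self.finished_req_ids.add(request.request_id)

    def update_from_output(self, request: FakeRequest) -> None:
        # This must be called before we make the EngineCoreOutput.
        stopped = self._check_stop(request)
        if stopped:
            self._free_request(request)
        # ... build the EngineCoreOutput for this request here ...


scheduler = SketchScheduler()
req = FakeRequest("req-0", num_output_tokens=4)
scheduler.update_from_output(req)
assert scheduler.finished_req_ids == {"req-0"}
assert scheduler.encoder_freed == ["req-0"]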