
Commit c4dd81a

WoosukKwon authored and NickLucche committed
[V1][BugFix] Free encoder cache for aborted requests (vllm-project#12545)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 72de617 commit c4dd81a

File tree: 2 files changed (+16, -7)

vllm/v1/core/encoder_cache_manager.py (+8, -1)
@@ -38,7 +38,8 @@ def allocate(self, request: Request, input_id: int) -> None:
     def get_cached_input_ids(self, request: Request) -> Set[int]:
         return self.cached.get(request.request_id, set())
 
-    def free(self, request: Request, input_id: int) -> None:
+    def free_encoder_input(self, request: Request, input_id: int) -> None:
+        """Free a single encoder input id for the request."""
         req_id = request.request_id
         if req_id not in self.cached:
             return
@@ -49,6 +50,12 @@ def free(self, request: Request, input_id: int) -> None:
         self.num_free_slots += request.get_num_encoder_tokens(input_id)
         self.freed.append((req_id, input_id))
 
+    def free(self, request: Request) -> None:
+        """Free all cached input ids for the request."""
+        input_ids = self.get_cached_input_ids(request)
+        for input_id in input_ids:
+            self.free_encoder_input(request, input_id)
+
     def get_freed_ids(self) -> List[Tuple[str, int]]:
         freed = self.freed
         self.freed = []

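The change above splits the old per-input free() into free_encoder_input() and adds a request-level free() that drains every input id still cached for the request. The following is a minimal, self-contained sketch of that shape for illustration only: EncoderCacheManagerSketch and FakeRequest are stand-ins invented here, and the real EncoderCacheManager carries extra bookkeeping (capacity checks, integration with the model runner) that is omitted.

from typing import Dict, List, Set, Tuple


class FakeRequest:
    """Hypothetical stand-in for vllm.v1.request.Request (illustration only)."""

    def __init__(self, request_id: str, encoder_token_counts: List[int]):
        self.request_id = request_id
        self._encoder_token_counts = encoder_token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._encoder_token_counts[input_id]


class EncoderCacheManagerSketch:
    """Simplified mirror of the free()/free_encoder_input() split above."""

    def __init__(self, cache_size: int):
        self.num_free_slots = cache_size
        self.cached: Dict[str, Set[int]] = {}
        self.freed: List[Tuple[str, int]] = []

    def allocate(self, request: FakeRequest, input_id: int) -> None:
        self.cached.setdefault(request.request_id, set()).add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: FakeRequest) -> Set[int]:
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: FakeRequest, input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        if req_id not in self.cached:
            return
        self.cached[req_id].discard(input_id)
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: FakeRequest) -> None:
        """Free all cached input ids still held for the request."""
        # Copy before iterating: free_encoder_input() mutates the cached set.
        for input_id in list(self.get_cached_input_ids(request)):
            self.free_encoder_input(request, input_id)


if __name__ == "__main__":
    manager = EncoderCacheManagerSketch(cache_size=32)
    req = FakeRequest("req-0", encoder_token_counts=[16, 8])
    manager.allocate(req, 0)
    manager.allocate(req, 1)
    manager.free(req)  # drains both inputs, as for an aborted request
    assert manager.num_free_slots == 32
    assert sorted(manager.freed) == [("req-0", 0), ("req-0", 1)]

In this sketch, free() simply delegates to free_encoder_input() per cached input, which keeps the per-input accounting (num_free_slots, freed) in one place.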
vllm/v1/core/scheduler.py (+8, -6)
@@ -202,7 +202,7 @@ def schedule(self) -> "SchedulerOutput":
                 # which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
                 if num_new_tokens == 0:
-                    # The happens when prompt length is divisible by the block
+                    # This happens when prompt length is divisible by the block
                     # size and all blocks are cached. Now we force to recompute
                     # the last block. Note that we have to re-compute an entire
                     # block because allocate_slots() assumes num_computed_tokens
@@ -269,6 +269,7 @@ def schedule(self) -> "SchedulerOutput":
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
+        num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
             num_common_prefix_blocks = (
@@ -433,7 +434,8 @@ def update_from_output(
                 if start_pos + num_tokens <= request.num_computed_tokens:
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
-                    self.encoder_cache_manager.free(request, input_id)
+                    self.encoder_cache_manager.free_encoder_input(
+                        request, input_id)
 
             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
@@ -445,8 +447,10 @@ def update_from_output(
                 # TODO: Update the KV cache manager for prefix caching.
 
                 # Check for stop and update request state.
-                # This must be called before me make the EngineCoreOutput.
+                # This must be called before we make the EngineCoreOutput.
                 stopped = self._check_stop(request)
+                if stopped:
+                    self._free_request(request)
 
                 # Add EngineCoreOutput for this Request.
                 output = EngineCoreOutput(
@@ -472,21 +476,18 @@ def _check_stop(self, request: Request) -> bool:
         if (request.num_tokens >= self.max_model_len
                 or request.num_output_tokens >= request.max_tokens):
             request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            self._free_request(request)
             return True
 
         sampling_params = request.sampling_params
         last_token_id = request.output_token_ids[-1]
         if (not sampling_params.ignore_eos
                 and last_token_id == request.eos_token_id):
             request.status = RequestStatus.FINISHED_STOPPED
-            self._free_request(request)
             return True
 
         if last_token_id in (sampling_params.stop_token_ids or ()):
             request.status = RequestStatus.FINISHED_STOPPED
             request.stop_reason = last_token_id
-            self._free_request(request)
             return True
         return False
 
@@ -525,6 +526,7 @@ def finish_requests(
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
         self.running_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)

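On the scheduler side, the net effect is that _check_stop() now only marks the request as finished, update_from_output() frees stopped requests explicitly, and _free_request() releases the encoder cache in addition to the KV cache. Since the abort path also ends in _free_request() (via finish_requests(), per the commit title), aborted multimodal requests no longer leak encoder cache entries. Below is a condensed sketch of that control flow using invented stubs (CacheStub, RequestStub, SchedulerSketch) rather than vLLM's real classes, with a finish_requests() signature simplified for illustration.

from dataclasses import dataclass, field
from typing import Set


@dataclass
class CacheStub:
    """Invented stand-in that only records which request ids were freed."""
    freed_ids: Set[str] = field(default_factory=set)

    def free(self, request: "RequestStub") -> None:
        self.freed_ids.add(request.request_id)


@dataclass
class RequestStub:
    request_id: str
    finished: bool = False

    def is_finished(self) -> bool:
        return self.finished


class SchedulerSketch:
    """Condensed mirror of the stop/abort paths after this commit."""

    def __init__(self) -> None:
        self.kv_cache_manager = CacheStub()
        self.encoder_cache_manager = CacheStub()

    def _check_stop(self, request: RequestStub, hit_stop: bool) -> bool:
        # After the change, _check_stop only updates request state and reports
        # whether the request stopped; it no longer calls _free_request itself.
        if hit_stop:
            request.finished = True
            return True
        return False

    def update_from_output(self, request: RequestStub, hit_stop: bool) -> None:
        stopped = self._check_stop(request, hit_stop)
        if stopped:
            self._free_request(request)

    def finish_requests(self, request: RequestStub) -> None:
        # Abort path (signature simplified): mark finished, then free.
        request.finished = True
        self._free_request(request)

    def _free_request(self, request: RequestStub) -> None:
        assert request.is_finished()
        self.kv_cache_manager.free(request)
        # New in this commit: also release cached encoder outputs, so aborted
        # multimodal requests no longer leave encoder cache space occupied.
        self.encoder_cache_manager.free(request)


if __name__ == "__main__":
    sched = SchedulerSketch()
    aborted = RequestStub("aborted-req")
    sched.finish_requests(aborted)
    assert "aborted-req" in sched.encoder_cache_manager.freed_ids

Routing every teardown through _free_request() keeps the stop path and the abort path symmetric, which is why freeing the encoder cache there fixes the abort-path leak without touching the normal per-input frees in update_from_output().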