From cf0cac2ba755ea641afbc48a79e21a31417167af Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Wed, 29 Jan 2025 11:50:42 +0800 Subject: [PATCH 1/8] Add _untouch() to reverse touch() if not enough blocks Signed-off-by: Shawn Du --- vllm/v1/core/kv_cache_manager.py | 33 ++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 18fdfdfe4a010..a7d5008d6adbe 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -207,18 +207,6 @@ def allocate_slots( raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in computed_blocks - if blk.ref_cnt == 0) - - num_required_blocks = cdiv(num_tokens, self.block_size) - if (num_required_blocks > self.free_block_queue.num_free_blocks - - num_evictable_computed_blocks): - # Cannot allocate new blocks. - return None - # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: self._touch(computed_blocks) @@ -227,6 +215,12 @@ def allocate_slots( "Computed blocks should be empty when " "prefix caching is disabled") + num_required_blocks = cdiv(num_tokens, self.block_size) + if num_required_blocks > self.free_block_queue.num_free_blocks: + # Cannot allocate new blocks + self._untouch(computed_blocks) + return None + # Determine the number of new blocks to allocate considering # preallocated blocks. num_new_blocks = min( @@ -471,6 +465,21 @@ def _touch(self, blocks: List[KVCacheBlock]) -> None: self.free_block_queue.remove(block) block.incr_ref() + def _untouch(self, blocks: List[KVCacheBlock]) -> None: + """Untouch a block decreases its reference count by 1, and may add + the block to the free queue. This is used to reverse touched blocks + that are hit by another request with the same prefix. + + Args: + blocks: A list of blocks to touch. + """ + for block in blocks: + block.decr_ref() + # ref_cnt=0 means this block can be added in the free list + # (i.e. eviction candidate). + if block.ref_cnt == 0: + self.free_block_queue.append(block) + def _cache_full_blocks( self, request: Request, From 5cedce5b77a1b45beefccfc1cea4b4352eab1912 Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Wed, 29 Jan 2025 13:34:36 +0800 Subject: [PATCH 2/8] Combine allocate_slots and append_slots Signed-off-by: Shawn Du --- tests/v1/core/test_prefix_caching.py | 14 ++-- vllm/v1/core/kv_cache_manager.py | 97 +++++++++++++++++++--------- vllm/v1/core/scheduler.py | 4 +- 3 files changed, 75 insertions(+), 40 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index f434fa8c61a80..eeb1d05de0255 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -164,7 +164,7 @@ def test_decode(): req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.append_slots(req0, 4) + new_blocks = manager.allocate_slots(req0, 4, []) assert new_blocks is not None and len(new_blocks) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is None @@ -175,7 +175,7 @@ def test_decode(): # the preallocated block. for _ in range(5 + 10): req0.append_output_token_ids(7) - new_blocks = manager.append_slots(req0, 15) + new_blocks = manager.allocate_slots(req0, 15, []) assert new_blocks is not None and len(new_blocks) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None @@ -185,7 +185,7 @@ def test_decode(): # the preallocated block. for _ in range(6 + 11): req0.append_output_token_ids(12) - new_blocks = manager.append_slots(req0, 17) + new_blocks = manager.allocate_slots(req0, 17, []) # Plus one preallocated block. assert new_blocks is not None and len(new_blocks) == 2 @@ -396,11 +396,11 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): assert len(blocks) == 1 + num_preallocated_blocks # Assume all computed. - manager.append_slots(req, block_size * (len(blocks) - 1)) + manager.allocate_slots(req, block_size * (len(blocks) - 1), []) req.num_computed_tokens = block_size * len(blocks) # Append 1 block. - blocks = manager.append_slots(req, block_size) + blocks = manager.allocate_slots(req, block_size, []) assert len(blocks) == 1 + num_preallocated_blocks @@ -503,7 +503,7 @@ def test_mm_prefix_caching(): # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.append_slots(req0, 5) + new_blocks = manager.allocate_slots(req0, 5, []) assert new_blocks is not None and len(new_blocks) == 0 # The just completed block should have hashes with extra keys. @@ -648,7 +648,7 @@ def test_uncache_blocks(): # Simulate speculative tokens. for _ in range(5): req0.append_output_token_ids(8) - manager.append_slots(req0, 5) + manager.allocate_slots(req0, 5, []) assert len(manager.cached_block_hash_to_block) == 2 # After sampling, assuming only 1 token is accepted. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index a7d5008d6adbe..cc02e05ad1102 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Optional, Tuple +from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv @@ -67,7 +67,7 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {} + self.req_to_blocks: DefaultDict[str, List[KVCacheBlock]] = defaultdict(list) def get_computed_blocks( self, request: Request) -> Tuple[List[KVCacheBlock], int]: @@ -190,15 +190,28 @@ def allocate_slots( self, request: Request, num_tokens: int, - computed_blocks: List[KVCacheBlock], + new_computed_blocks: List[KVCacheBlock] ) -> Optional[List[KVCacheBlock]]: - """Allocate slots for a new request. + """Add slots for a new prefill or a new decode request. Args: request: The request to allocate slots. num_tokens: The number of tokens to allocate. Note that this does not include the tokens that have already been computed. - computed_blocks: A list of computed blocks. + new_computed_blocks: A list of new computed blocks just hitting the + prefix caching. + + Blocks layout: + ----------------------------------------------------------------------- + | < computed > | < new computed > | < new > | < pre-allocated > | + ----------------------------------------------------------------------- + | < required > | + -------------------------------------------------- + | < full > | + ------------------------------------------------ + | | + -------------- + The following *_blocks are illustrated in this layout. Returns: A list of new allocated blocks. @@ -209,51 +222,73 @@ def allocate_slots( # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: - self._touch(computed_blocks) + self._touch(new_computed_blocks) else: - assert not computed_blocks, ( + assert not new_computed_blocks, ( "Computed blocks should be empty when " "prefix caching is disabled") - num_required_blocks = cdiv(num_tokens, self.block_size) - if num_required_blocks > self.free_block_queue.num_free_blocks: + # The number of computed tokens is the number of computed tokens plus + # the new prefix cacheing hits + num_computed_tokens = (request.num_computed_tokens + + len(new_computed_blocks) * self.block_size) + num_required_blocks = cdiv(num_computed_tokens + num_tokens, + self.block_size) + req_blocks = self.req_to_blocks[request.request_id] + num_new_blocks = (num_required_blocks - len(req_blocks) - + len(new_computed_blocks)) + if num_new_blocks > self.free_block_queue.num_free_blocks: # Cannot allocate new blocks - self._untouch(computed_blocks) + self._untouch(new_computed_blocks) return None + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + req_blocks.extend(new_computed_blocks) + + # Start to handle new blocks + # Determine the number of new blocks to allocate considering # preallocated blocks. - num_new_blocks = min( - num_required_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. - self.max_num_blocks_per_req - len(computed_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self._get_new_blocks(num_new_blocks) - self.req_to_blocks[request.request_id] = computed_blocks + new_blocks + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + num_new_blocks = min( + num_new_blocks + self.num_preallocate_blocks, + self.free_block_queue.num_free_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(req_blocks), + ) + assert num_new_blocks > 0 + + # Concatenate the computed block IDs and the new block IDs. + new_blocks = self._get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) if not self.enable_caching: return new_blocks - num_computed_tokens = len(computed_blocks) * self.block_size + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - - new_full_blocks = self.req_to_blocks[ - request.request_id][len(computed_blocks):num_full_blocks] + num_computed_full_blocks = num_computed_tokens // self.block_size + new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] if new_full_blocks: self._cache_full_blocks( request=request, - blk_start_idx=len(computed_blocks), + blk_start_idx=num_computed_full_blocks, # The new full blocks are the full blocks that are not computed. full_blocks=new_full_blocks, - prev_block=computed_blocks[-1] if computed_blocks else None, + prev_block=(req_blocks[num_computed_full_blocks - 1] + if num_computed_full_blocks > 0 else None) ) return new_blocks diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 7a88cc9433b32..09d5665642022 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -138,8 +138,8 @@ def schedule(self) -> "SchedulerOutput": assert num_new_tokens > 0 while True: - new_blocks = self.kv_cache_manager.append_slots( - request, num_new_tokens) + new_blocks = self.kv_cache_manager.allocate_slots( + request, num_new_tokens, []) if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. From a482f5db56b79083bffddfddc9def7e2d048d303 Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Wed, 29 Jan 2025 19:00:07 +0800 Subject: [PATCH 3/8] Delete append_slots Signed-off-by: Shawn Du --- vllm/v1/core/kv_cache_manager.py | 76 -------------------------------- 1 file changed, 76 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index cc02e05ad1102..0918f1404d1b9 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -110,82 +110,6 @@ def get_computed_blocks( num_computed_tokens = len(computed_blocks) * self.block_size return computed_blocks, num_computed_tokens - def append_slots( - self, - request: Request, - num_tokens: int, - ) -> Optional[List[KVCacheBlock]]: - """Append slots to the block table of the request. - We first append slots to already allocated blocks. If the allocated - blocks are not enough, we allocate new blocks. - - Args: - request: The request to append slots. - num_tokens: The number of tokens to append. - - Returns: - A list of new blocks if new blocks are allocated, or None - if new blocks are required but cannot be allocated. - """ - num_required_blocks = cdiv(request.num_computed_tokens + num_tokens, - self.block_size) - req_blocks = self.req_to_blocks[request.request_id] - - num_new_blocks = num_required_blocks - len(req_blocks) - if num_new_blocks > self.free_block_queue.num_free_blocks: - # Need to allocate new blocks due to insufficient pre-allocated - # slots, but we cannot allocate new blocks due to the limit. - return None - - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool considering - # preallocated blocks. - num_new_blocks = min( - num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - new_blocks = self._get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) - - if not self.enable_caching: - return new_blocks - - num_computed_full_blocks = (request.num_computed_tokens // - self.block_size) - - # NOTE(rickyx): We are assuming the `num_tokens` are actual - # tokens rather than lookahead slots (e.g. for speculative decoding). - # TODO(rickyx): When supporting speculative decoding, we will need to - # differentiate between them so that we can know how many blocks are - # full after appending the actual tokens. - num_full_blocks_after_append = (request.num_computed_tokens + - num_tokens) // self.block_size - assert num_full_blocks_after_append <= len(req_blocks) - - new_full_blocks = req_blocks[ - num_computed_full_blocks:num_full_blocks_after_append] - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=num_computed_full_blocks, - full_blocks=new_full_blocks, - prev_block=req_blocks[num_computed_full_blocks - 1] - if num_computed_full_blocks >= 1 else None, - ) - - return new_blocks - def allocate_slots( self, request: Request, From 23772e92df7053fdba54f152a8c35d0cb40875c1 Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Wed, 29 Jan 2025 22:16:33 +0800 Subject: [PATCH 4/8] Modify test case in prefix caching Assume in both prefill and decode, num_tokens should not be zero, previously only prefill assumed this. Signed-off-by: Shawn Du --- tests/v1/core/test_prefix_caching.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index eeb1d05de0255..bbc4e6e297d97 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -395,9 +395,11 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): req.num_computed_tokens = block_size assert len(blocks) == 1 + num_preallocated_blocks - # Assume all computed. - manager.allocate_slots(req, block_size * (len(blocks) - 1), []) - req.num_computed_tokens = block_size * len(blocks) + # Assume all computed, only when num_preallocate_tokens > 0, we need to + # consume the previously preallocated blocks. + if num_preallocated_blocks > 0: + manager.allocate_slots(req, block_size * (len(blocks) - 1), []) + req.num_computed_tokens = block_size * len(blocks) # Append 1 block. blocks = manager.allocate_slots(req, block_size, []) From 075f1b5b8015e343f25b45d36ccff571cb39739b Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Fri, 31 Jan 2025 20:30:21 +0800 Subject: [PATCH 5/8] Address static checkers Signed-off-by: Shawn Du --- vllm/v1/core/kv_cache_manager.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0918f1404d1b9..643e3e6d7e082 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -67,7 +67,8 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: DefaultDict[str, List[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: DefaultDict[str, + List[KVCacheBlock]] = defaultdict(list) def get_computed_blocks( self, request: Request) -> Tuple[List[KVCacheBlock], int]: @@ -111,9 +112,7 @@ def get_computed_blocks( return computed_blocks, num_computed_tokens def allocate_slots( - self, - request: Request, - num_tokens: int, + self, request: Request, num_tokens: int, new_computed_blocks: List[KVCacheBlock] ) -> Optional[List[KVCacheBlock]]: """Add slots for a new prefill or a new decode request. @@ -212,8 +211,7 @@ def allocate_slots( # The new full blocks are the full blocks that are not computed. full_blocks=new_full_blocks, prev_block=(req_blocks[num_computed_full_blocks - 1] - if num_computed_full_blocks > 0 else None) - ) + if num_computed_full_blocks > 0 else None)) return new_blocks From c7ef00378433aa502175c604c54188578a25d992 Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Sat, 1 Feb 2025 11:37:48 +0800 Subject: [PATCH 6/8] Address reviewer comments Signed-off-by: Shawn Du --- tests/v1/core/test_prefix_caching.py | 18 +++++++++--------- vllm/v1/core/kv_cache_manager.py | 8 ++++---- vllm/v1/core/scheduler.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index bbc4e6e297d97..5c1cda285fb1d 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -164,7 +164,7 @@ def test_decode(): req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 4, []) + new_blocks = manager.allocate_slots(req0, 4) assert new_blocks is not None and len(new_blocks) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is None @@ -175,7 +175,7 @@ def test_decode(): # the preallocated block. for _ in range(5 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 15, []) + new_blocks = manager.allocate_slots(req0, 15) assert new_blocks is not None and len(new_blocks) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None @@ -185,7 +185,7 @@ def test_decode(): # the preallocated block. for _ in range(6 + 11): req0.append_output_token_ids(12) - new_blocks = manager.allocate_slots(req0, 17, []) + new_blocks = manager.allocate_slots(req0, 17) # Plus one preallocated block. assert new_blocks is not None and len(new_blocks) == 2 @@ -398,11 +398,11 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): # Assume all computed, only when num_preallocate_tokens > 0, we need to # consume the previously preallocated blocks. if num_preallocated_blocks > 0: - manager.allocate_slots(req, block_size * (len(blocks) - 1), []) + manager.allocate_slots(req, block_size * (len(blocks) - 1)) req.num_computed_tokens = block_size * len(blocks) # Append 1 block. - blocks = manager.allocate_slots(req, block_size, []) + blocks = manager.allocate_slots(req, block_size) assert len(blocks) == 1 + num_preallocated_blocks @@ -505,7 +505,7 @@ def test_mm_prefix_caching(): # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5, []) + new_blocks = manager.allocate_slots(req0, 5) assert new_blocks is not None and len(new_blocks) == 0 # The just completed block should have hashes with extra keys. @@ -605,7 +605,7 @@ def test_reset_prefix_cache(): unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) - blocks = manager.allocate_slots(req0, 55, []) + blocks = manager.allocate_slots(req0, 55) assert [b.block_id for b in blocks] == [0, 1, 2, 3] unique_token_ids = [4] * 7 @@ -641,7 +641,7 @@ def test_uncache_blocks(): ) req0 = make_request("0", list(range(30))) - blocks = manager.allocate_slots(req0, 30, []) + blocks = manager.allocate_slots(req0, 30) assert [b.block_id for b in blocks] == [0, 1] assert len(manager.cached_block_hash_to_block) == 1 @@ -650,7 +650,7 @@ def test_uncache_blocks(): # Simulate speculative tokens. for _ in range(5): req0.append_output_token_ids(8) - manager.allocate_slots(req0, 5, []) + manager.allocate_slots(req0, 5) assert len(manager.cached_block_hash_to_block) == 2 # After sampling, assuming only 1 token is accepted. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 643e3e6d7e082..bdbc57b4b7cd7 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -113,7 +113,7 @@ def get_computed_blocks( def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: List[KVCacheBlock] + new_computed_blocks: Optional[List[KVCacheBlock]] = None ) -> Optional[List[KVCacheBlock]]: """Add slots for a new prefill or a new decode request. @@ -140,9 +140,9 @@ def allocate_slots( A list of new allocated blocks. """ if num_tokens == 0: - raise ValueError( - f"num_tokens must be greater than 0, got {num_tokens}") + raise ValueError(f"num_tokens must be greater than 0") + new_computed_blocks = new_computed_blocks or [] # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: self._touch(new_computed_blocks) @@ -152,7 +152,7 @@ def allocate_slots( "prefix caching is disabled") # The number of computed tokens is the number of computed tokens plus - # the new prefix cacheing hits + # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + len(new_computed_blocks) * self.block_size) num_required_blocks = cdiv(num_computed_tokens + num_tokens, diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 09d5665642022..066751cda1f05 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -139,7 +139,7 @@ def schedule(self) -> "SchedulerOutput": while True: new_blocks = self.kv_cache_manager.allocate_slots( - request, num_new_tokens, []) + request, num_new_tokens) if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. From 8b2172a03e66e91af66c7b73a636220c0b3b9b70 Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Sat, 1 Feb 2025 11:58:27 +0800 Subject: [PATCH 7/8] Remove _untouch Signed-off-by: Shawn Du --- vllm/v1/core/kv_cache_manager.py | 40 +++++++++++++------------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index bdbc57b4b7cd7..19ec8b885a7a4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -143,13 +143,6 @@ def allocate_slots( raise ValueError(f"num_tokens must be greater than 0") new_computed_blocks = new_computed_blocks or [] - # Touch the computed blocks to make sure they won't be evicted. - if self.enable_caching: - self._touch(new_computed_blocks) - else: - assert not new_computed_blocks, ( - "Computed blocks should be empty when " - "prefix caching is disabled") # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits @@ -160,11 +153,25 @@ def allocate_slots( req_blocks = self.req_to_blocks[request.request_id] num_new_blocks = (num_required_blocks - len(req_blocks) - len(new_computed_blocks)) - if num_new_blocks > self.free_block_queue.num_free_blocks: + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + if blk.ref_cnt == 0) + if (num_new_blocks > self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks): # Cannot allocate new blocks - self._untouch(new_computed_blocks) return None + # Touch the computed blocks to make sure they won't be evicted. + if self.enable_caching: + self._touch(new_computed_blocks) + else: + assert not new_computed_blocks, ( + "Computed blocks should be empty when " + "prefix caching is disabled") + # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. req_blocks.extend(new_computed_blocks) @@ -422,21 +429,6 @@ def _touch(self, blocks: List[KVCacheBlock]) -> None: self.free_block_queue.remove(block) block.incr_ref() - def _untouch(self, blocks: List[KVCacheBlock]) -> None: - """Untouch a block decreases its reference count by 1, and may add - the block to the free queue. This is used to reverse touched blocks - that are hit by another request with the same prefix. - - Args: - blocks: A list of blocks to touch. - """ - for block in blocks: - block.decr_ref() - # ref_cnt=0 means this block can be added in the free list - # (i.e. eviction candidate). - if block.ref_cnt == 0: - self.free_block_queue.append(block) - def _cache_full_blocks( self, request: Request, From b4256766f6a9f920ac9b253cc3fad7d0ed3bdf1b Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Sat, 1 Feb 2025 17:17:28 +0800 Subject: [PATCH 8/8] Address comments Signed-off-by: Shawn Du --- vllm/v1/core/kv_cache_manager.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 19ec8b885a7a4..7a67af9448a09 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -112,10 +112,12 @@ def get_computed_blocks( return computed_blocks, num_computed_tokens def allocate_slots( - self, request: Request, num_tokens: int, + self, + request: Request, + num_tokens: int, new_computed_blocks: Optional[List[KVCacheBlock]] = None ) -> Optional[List[KVCacheBlock]]: - """Add slots for a new prefill or a new decode request. + """Add slots for a request with new tokens to append. Args: request: The request to allocate slots. @@ -140,7 +142,7 @@ def allocate_slots( A list of new allocated blocks. """ if num_tokens == 0: - raise ValueError(f"num_tokens must be greater than 0") + raise ValueError("num_tokens must be greater than 0") new_computed_blocks = new_computed_blocks or [] @@ -178,12 +180,12 @@ def allocate_slots( # Start to handle new blocks - # Determine the number of new blocks to allocate considering - # preallocated blocks. if num_new_blocks <= 0: # No new block is needed. new_blocks = [] else: + # Get new blocks from the free block pool considering + # preallocated blocks. num_new_blocks = min( num_new_blocks + self.num_preallocate_blocks, self.free_block_queue.num_free_blocks,