From f8ece6e17fbf4ff3a98d6d53cb3a03c50c02828c Mon Sep 17 00:00:00 2001 From: Shawn Du Date: Sun, 2 Feb 2025 16:40:58 +0800 Subject: [PATCH] [Core][v1] Unify allocating slots in prefill and decode in KV cache manager (#12608) As mentioned in RFC https://github.com/vllm-project/vllm/issues/12254, this PR combines allocate_slots and append_slots into a single allocate_slots method. There should be no functional change, except that the decode path now also raises an exception when num_tokens is zero (as prefill already does); the unit tests are updated accordingly. @comaniac @rickyyx @WoosukKwon @youkaichao @heheda12345 @simon-mo --------- Signed-off-by: Shawn Du --- tests/v1/core/test_prefix_caching.py | 24 ++-- vllm/v1/core/kv_cache_manager.py | 168 ++++++++++----------------- vllm/v1/core/scheduler.py | 2 +- 3 files changed, 78 insertions(+), 116 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index f434fa8c61a80..5c1cda285fb1d 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -164,7 +164,7 @@ def test_decode(): req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.append_slots(req0, 4) + new_blocks = manager.allocate_slots(req0, 4) assert new_blocks is not None and len(new_blocks) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is None @@ -175,7 +175,7 @@ def test_decode(): # the preallocated block. for _ in range(5 + 10): req0.append_output_token_ids(7) - new_blocks = manager.append_slots(req0, 15) + new_blocks = manager.allocate_slots(req0, 15) assert new_blocks is not None and len(new_blocks) == 0 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None @@ -185,7 +185,7 @@ def test_decode(): # the preallocated block. for _ in range(6 + 11): req0.append_output_token_ids(12) - new_blocks = manager.append_slots(req0, 17) + new_blocks = manager.allocate_slots(req0, 17) # Plus one preallocated block. 
assert new_blocks is not None and len(new_blocks) == 2 @@ -395,12 +395,14 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): req.num_computed_tokens = block_size assert len(blocks) == 1 + num_preallocated_blocks - # Assume all computed. - manager.append_slots(req, block_size * (len(blocks) - 1)) - req.num_computed_tokens = block_size * len(blocks) + # Assume all computed, only when num_preallocate_tokens > 0, we need to + # consume the previously preallocated blocks. + if num_preallocated_blocks > 0: + manager.allocate_slots(req, block_size * (len(blocks) - 1)) + req.num_computed_tokens = block_size * len(blocks) # Append 1 block. - blocks = manager.append_slots(req, block_size) + blocks = manager.allocate_slots(req, block_size) assert len(blocks) == 1 + num_preallocated_blocks @@ -503,7 +505,7 @@ def test_mm_prefix_caching(): # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.append_slots(req0, 5) + new_blocks = manager.allocate_slots(req0, 5) assert new_blocks is not None and len(new_blocks) == 0 # The just completed block should have hashes with extra keys. @@ -603,7 +605,7 @@ def test_reset_prefix_cache(): unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) - blocks = manager.allocate_slots(req0, 55, []) + blocks = manager.allocate_slots(req0, 55) assert [b.block_id for b in blocks] == [0, 1, 2, 3] unique_token_ids = [4] * 7 @@ -639,7 +641,7 @@ def test_uncache_blocks(): ) req0 = make_request("0", list(range(30))) - blocks = manager.allocate_slots(req0, 30, []) + blocks = manager.allocate_slots(req0, 30) assert [b.block_id for b in blocks] == [0, 1] assert len(manager.cached_block_hash_to_block) == 1 @@ -648,7 +650,7 @@ def test_uncache_blocks(): # Simulate speculative tokens. 
for _ in range(5): req0.append_output_token_ids(8) - manager.append_slots(req0, 5) + manager.allocate_slots(req0, 5) assert len(manager.cached_block_hash_to_block) == 2 # After sampling, assuming only 1 token is accepted. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index d6c612f155f01..7176ec9544f99 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Optional, Tuple +from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple from vllm.logger import init_logger from vllm.utils import cdiv @@ -67,7 +67,8 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {} + self.req_to_blocks: DefaultDict[str, + List[KVCacheBlock]] = defaultdict(list) @property def usage(self) -> float: @@ -115,33 +116,75 @@ def get_computed_blocks( num_computed_tokens = len(computed_blocks) * self.block_size return computed_blocks, num_computed_tokens - def append_slots( + def allocate_slots( self, request: Request, num_tokens: int, + new_computed_blocks: Optional[List[KVCacheBlock]] = None ) -> Optional[List[KVCacheBlock]]: - """Append slots to the block table of the request. - We first append slots to already allocated blocks. If the allocated - blocks are not enough, we allocate new blocks. + """Add slots for a request with new tokens to append. Args: - request: The request to append slots. - num_tokens: The number of tokens to append. + request: The request to allocate slots. + num_tokens: The number of tokens to allocate. Note that this does + not include the tokens that have already been computed. + new_computed_blocks: A list of new computed blocks just hitting the + prefix caching. 
+ + Blocks layout: + ----------------------------------------------------------------------- + | < computed > | < new computed > | < new > | < pre-allocated > | + ----------------------------------------------------------------------- + | < required > | + -------------------------------------------------- + | < full > | + ------------------------------------------------ + | | + -------------- + The following *_blocks are illustrated in this layout. Returns: - A list of new blocks if new blocks are allocated, or None - if new blocks are required but cannot be allocated. + A list of new allocated blocks. """ - num_required_blocks = cdiv(request.num_computed_tokens + num_tokens, + if num_tokens == 0: + raise ValueError("num_tokens must be greater than 0") + + new_computed_blocks = new_computed_blocks or [] + + # The number of computed tokens is the number of computed tokens plus + # the new prefix caching hits + num_computed_tokens = (request.num_computed_tokens + + len(new_computed_blocks) * self.block_size) + num_required_blocks = cdiv(num_computed_tokens + num_tokens, self.block_size) req_blocks = self.req_to_blocks[request.request_id] + num_new_blocks = (num_required_blocks - len(req_blocks) - + len(new_computed_blocks)) - num_new_blocks = num_required_blocks - len(req_blocks) - if num_new_blocks > self.free_block_queue.num_free_blocks: - # Need to allocate new blocks due to insufficient pre-allocated - # slots, but we cannot allocate new blocks due to the limit. + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks + if blk.ref_cnt == 0) + if (num_new_blocks > self.free_block_queue.num_free_blocks - + num_evictable_computed_blocks): + # Cannot allocate new blocks return None + # Touch the computed blocks to make sure they won't be evicted. 
+ if self.enable_caching: + self._touch(new_computed_blocks) + else: + assert not new_computed_blocks, ( + "Computed blocks should be empty when " + "prefix caching is disabled") + + # Append the new computed blocks to the request blocks until now to + # avoid the case where the new blocks cannot be allocated. + req_blocks.extend(new_computed_blocks) + + # Start to handle new blocks + if num_new_blocks <= 0: # No new block is needed. new_blocks = [] @@ -160,112 +203,29 @@ def append_slots( ) assert num_new_blocks > 0 + # Concatenate the computed block IDs and the new block IDs. new_blocks = self._get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) if not self.enable_caching: return new_blocks - num_computed_full_blocks = (request.num_computed_tokens // - self.block_size) - # NOTE(rickyx): We are assuming the `num_tokens` are actual # tokens rather than lookahead slots (e.g. for speculative decoding). # TODO(rickyx): When supporting speculative decoding, we will need to # differentiate between them so that we can know how many blocks are # full after appending the actual tokens. - num_full_blocks_after_append = (request.num_computed_tokens + - num_tokens) // self.block_size - assert num_full_blocks_after_append <= len(req_blocks) - - new_full_blocks = req_blocks[ - num_computed_full_blocks:num_full_blocks_after_append] - if new_full_blocks: - self._cache_full_blocks( - request=request, - blk_start_idx=num_computed_full_blocks, - full_blocks=new_full_blocks, - prev_block=req_blocks[num_computed_full_blocks - 1] - if num_computed_full_blocks >= 1 else None, - ) - - return new_blocks - - def allocate_slots( - self, - request: Request, - num_tokens: int, - computed_blocks: List[KVCacheBlock], - ) -> Optional[List[KVCacheBlock]]: - """Allocate slots for a new request. - - Args: - request: The request to allocate slots. - num_tokens: The number of tokens to allocate. Note that this does - not include the tokens that have already been computed. 
- computed_blocks: A list of computed blocks. - - Returns: - A list of new allocated blocks. - """ - if num_tokens == 0: - raise ValueError( - f"num_tokens must be greater than 0, got {num_tokens}") - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in computed_blocks - if blk.ref_cnt == 0) - - num_required_blocks = cdiv(num_tokens, self.block_size) - if (num_required_blocks > self.free_block_queue.num_free_blocks - - num_evictable_computed_blocks): - # Cannot allocate new blocks. - return None - - # Touch the computed blocks to make sure they won't be evicted. - if self.enable_caching: - self._touch(computed_blocks) - else: - assert not computed_blocks, ( - "Computed blocks should be empty when " - "prefix caching is disabled") - - # Determine the number of new blocks to allocate considering - # preallocated blocks. - num_new_blocks = min( - num_required_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks, - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - # TODO(woosuk): Check and reject requests if - # num_prompt_tokens + max_tokens > max_model_len. - self.max_num_blocks_per_req - len(computed_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. 
- new_blocks = self._get_new_blocks(num_new_blocks) - self.req_to_blocks[request.request_id] = computed_blocks + new_blocks - - if not self.enable_caching: - return new_blocks - - num_computed_tokens = len(computed_blocks) * self.block_size num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - - new_full_blocks = self.req_to_blocks[ - request.request_id][len(computed_blocks):num_full_blocks] + num_computed_full_blocks = num_computed_tokens // self.block_size + new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks] if new_full_blocks: self._cache_full_blocks( request=request, - blk_start_idx=len(computed_blocks), + blk_start_idx=num_computed_full_blocks, # The new full blocks are the full blocks that are not computed. full_blocks=new_full_blocks, - prev_block=computed_blocks[-1] if computed_blocks else None, - ) + prev_block=(req_blocks[num_computed_full_blocks - 1] + if num_computed_full_blocks > 0 else None)) return new_blocks diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 910fc4ff4d2b6..27c9ac1ae353c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -138,7 +138,7 @@ def schedule(self) -> "SchedulerOutput": assert num_new_tokens > 0 while True: - new_blocks = self.kv_cache_manager.append_slots( + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens) if new_blocks is None: # The request cannot be scheduled.