diff --git a/benchmarks/bench_append_paged_kv_cache.py b/benchmarks/bench_append_paged_kv_cache.py index 0cc7d0e1..b55d5aec 100644 --- a/benchmarks/bench_append_paged_kv_cache.py +++ b/benchmarks/bench_append_paged_kv_cache.py @@ -99,12 +99,19 @@ def main(): dtype=torch.int32, ) + batch_indices, positions = flashinfer.get_batch_indices_positions( + x_indptr, + flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len), + k.shape[0], + ) + @torch.cuda.nvtx.range(f"model={model_name}, seqlens={seqlens}") def fn(): flashinfer.append_paged_kv_cache( k, v, - x_indptr, + batch_indices, + positions, layer_buf, kv_indices, kv_indptr, diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index f3663f27..fa256be3 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -249,38 +249,34 @@ __global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t paged_k * \param paged_kv The paged key-value cache * \param key The key to be appended * \param value The value to be appended - * \param append_indptr The indptr array of the appended ragged tensor + * \param batch_indices The batch indices of elements to be appended + * \param positions The positions of elements to be appended */ template -__global__ void AppendPagedKVCachePrefillKernel(paged_kv_t paged_kv, - DType* __restrict__ key, DType* __restrict__ value, - IdType* __restrict__ append_indptr) { +__global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, + DType* __restrict__ append_key, + DType* __restrict__ append_value, + IdType* __restrict__ batch_indices, + IdType* __restrict__ positions, uint32_t nnz, + size_t append_k_stride_n, size_t append_k_stride_h, + size_t append_v_stride_n, size_t append_v_stride_h) { uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t num_heads = paged_kv.num_heads; - uint32_t batch_idx = blockIdx.x; uint32_t head_idx = ty; - - uint32_t seq_len = - (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size + - paged_kv.last_page_len[batch_idx]; - uint32_t append_seq_len = append_indptr[batch_idx + 1] - append_indptr[batch_idx]; - uint32_t append_start = seq_len - append_seq_len; - -#pragma unroll 2 - for (uint32_t j = 0; j < append_seq_len; ++j) { - uint32_t page_seq_idx = j + append_start; - uint32_t page_iter = paged_kv.indptr[batch_idx] + page_seq_idx / paged_kv.page_size; - uint32_t entry_idx = page_seq_idx % paged_kv.page_size; - + uint32_t cta_id = blockIdx.x; + uint32_t num_ctas = gridDim.x; + +#pragma unroll 4 + for (uint32_t i = cta_id; i < nnz; i += num_ctas) { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod(paged_kv.indptr[batch_indices[i]] * paged_kv.page_size + positions[i], + page_iter, entry_idx); DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size); DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size); vec_t::memcpy( - k_ptr, - key + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size); - + k_ptr, append_key + i * append_k_stride_n + head_idx * append_k_stride_h + tx * vec_size); vec_t::memcpy( - v_ptr, - value + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size); + v_ptr, append_value + i * append_v_stride_n + head_idx * append_v_stride_h + tx * vec_size); } } @@ -327,20 +323,36 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t paged_kv, DType* * \return status Indicates whether CUDA calls are successful */ template -cudaError_t AppendPagedKVCache(paged_kv_t paged_kv, DType* key, DType* value, - 
IdType* append_indptr, cudaStream_t stream = nullptr) { +cudaError_t AppendPagedKVCache(paged_kv_t paged_kv, DType* append_key, + DType* append_value, IdType* batch_indices, IdType* positions, + uint32_t nnz, size_t append_k_stride_n, size_t append_k_stride_h, + size_t append_v_stride_n, size_t append_v_stride_h, + cudaStream_t stream = nullptr) { uint32_t head_dim = paged_kv.head_dim; - uint32_t batch_size = paged_kv.batch_size; uint32_t num_heads = paged_kv.num_heads; + int dev_id = 0; + int num_sms = 0; + int num_blocks_per_sm = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); uint32_t bdx = HEAD_DIM / vec_size; uint32_t bdy = num_heads; - // NOTE(Zihao): could be slow for small batch size, will optimize later - dim3 nblks(batch_size); + uint32_t num_threads = bdx * bdy; + uint32_t smem_size = 0; + auto kernel = AppendPagedKVCacheKernel; + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms)); + dim3 nblks(num_blocks_per_sm * num_sms); dim3 nthrs(bdx, bdy); - auto kernel = AppendPagedKVCachePrefillKernel; - void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, (void*)&append_indptr}; + + void* args[] = {(void*)&paged_kv, (void*)&append_key, (void*)&append_value, + (void*)&batch_indices, (void*)&positions, (void*)&nnz, + (void*)&append_k_stride_n, (void*)&append_k_stride_h, (void*)&append_v_stride_n, + (void*)&append_v_stride_h}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); }); return cudaSuccess; diff --git a/python/csrc/flashinfer_page_ops.cu b/python/csrc/flashinfer_page_ops.cu index 604c4156..aacaa485 100644 --- a/python/csrc/flashinfer_page_ops.cu +++ b/python/csrc/flashinfer_page_ops.cu @@ -16,10 +16,10 @@ #include void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor kv_indices, - torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, - unsigned int layout); + torch::Tensor batch_indices, torch::Tensor positions, + torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, + torch::Tensor kv_indices, torch::Tensor kv_indptr, + torch::Tensor kv_last_page_len, unsigned int layout); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 79aab9ee..f580c8c5 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -20,13 +20,14 @@ using namespace flashinfer; void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor kv_indices, - torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, - unsigned int layout) { - CHECK_INPUT(append_key); - CHECK_INPUT(append_value); - CHECK_INPUT(append_indptr); + torch::Tensor batch_indices, torch::Tensor positions, + torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, + torch::Tensor kv_indices, torch::Tensor kv_indptr, + torch::Tensor kv_last_page_len, unsigned int layout) { + CHECK_LAST_DIM_CONTIGUOUS(append_key); + 
CHECK_LAST_DIM_CONTIGUOUS(append_value); + CHECK_INPUT(batch_indices); + CHECK_INPUT(positions); // NOTE(Zihao): doesn't have to be contiguous CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_k_cache); CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_v_cache); @@ -35,20 +36,24 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, CHECK_INPUT(kv_last_page_len); CHECK_DIM(3, append_key); CHECK_DIM(3, append_value); - CHECK_DIM(1, append_indptr); + CHECK_DIM(1, batch_indices); + CHECK_DIM(1, positions); CHECK_DIM(4, paged_k_cache); CHECK_DIM(4, paged_v_cache); CHECK_DIM(1, kv_indices); CHECK_DIM(1, kv_indptr); CHECK_DIM(1, kv_last_page_len); + unsigned int nnz = append_key.size(0); unsigned int batch_size = kv_last_page_len.size(0); - CHECK_EQ(append_indptr.size(0), batch_size + 1); CHECK_EQ(kv_indptr.size(0), batch_size + 1); - CHECK_EQ(append_indptr.scalar_type(), torch::kInt32); + CHECK_EQ(batch_indices.size(0), nnz); + CHECK_EQ(positions.size(0), nnz); + CHECK_EQ(batch_indices.scalar_type(), torch::kInt32); + CHECK_EQ(positions.scalar_type(), torch::kInt32); CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32); CHECK_EQ(kv_indices.scalar_type(), torch::kInt32); CHECK_EQ(kv_last_page_len.scalar_type(), torch::kInt32); - auto device = append_indptr.device(); + auto device = append_key.device(); CHECK_EQ(append_key.device(), device); CHECK_EQ(append_value.device(), device); CHECK_EQ(paged_k_cache.device(), device); @@ -76,10 +81,17 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); kv_cache_strides = k_strides.data(); + auto append_k_strides = append_key.strides(); + auto append_k_stride_n = append_k_strides[0]; + auto append_k_stride_h = append_k_strides[1]; + auto append_v_strides = append_value.strides(); + auto append_v_stride_n = append_v_strides[0]; + auto append_v_stride_h = append_v_strides[1]; + CHECK_EQ(append_key.size(1), num_heads); CHECK_EQ(append_key.size(2), head_dim); CHECK_EQ(append_value.size(1), num_heads); - CHECK_EQ(append_key.size(2), head_dim); + CHECK_EQ(append_value.size(2), head_dim); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -92,10 +104,12 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, static_cast(paged_v_cache.data_ptr()), kv_cache_strides, static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), static_cast(kv_last_page_len.data_ptr())); - cudaError_t status = - AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), - static_cast(append_value.data_ptr()), - static_cast(append_indptr.data_ptr()), torch_current_stream); + cudaError_t status = AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), + static_cast(append_value.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), nnz, + append_k_stride_n, append_k_stride_h, append_v_stride_n, + append_v_stride_h, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "AppendPagedKVCache failed with error: ", cudaGetErrorString(status)); return true; diff --git a/python/csrc_aot/flashinfer_ops.cu b/python/csrc_aot/flashinfer_ops.cu index e32ac3e9..80039260 100644 --- a/python/csrc_aot/flashinfer_ops.cu +++ b/python/csrc_aot/flashinfer_ops.cu @@ -80,10 +80,10 @@ void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torc //========== page ========== void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - 
torch::Tensor append_indptr, torch::Tensor paged_k_cache, - torch::Tensor paged_v_cache, torch::Tensor kv_indices, - torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, - unsigned int layout); + torch::Tensor batch_indices, torch::Tensor positions, + torch::Tensor paged_k_cache, torch::Tensor paged_v_cache, + torch::Tensor kv_indices, torch::Tensor kv_indptr, + torch::Tensor kv_last_page_len, unsigned int layout); //========== prefill ========== diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 5284f268..54cd0d31 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -46,6 +46,8 @@ from .norm import gemma_rmsnorm as gemma_rmsnorm from .norm import rmsnorm as rmsnorm from .page import append_paged_kv_cache as append_paged_kv_cache +from .page import get_batch_indices_positions as get_batch_indices_positions +from .page import get_seq_lens as get_seq_lens from .prefill import ( BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper, ) diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index cc3cbb93..be206b2d 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -14,9 +14,11 @@ limitations under the License. """ -from typing import Optional +from typing import Optional, Tuple import torch +import triton +import triton.language as tl from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops from .utils import ( @@ -55,7 +57,8 @@ def get_page_module(): def _append_paged_kv_cache_kernel( append_key: torch.Tensor, append_value: torch.Tensor, - append_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, paged_k_cache: Optional[torch.Tensor], paged_v_cache: Optional[torch.Tensor], kv_indices: torch.Tensor, @@ -66,7 +69,8 @@ def _append_paged_kv_cache_kernel( get_page_module().append_paged_kv_cache( append_key, append_value, - append_indptr, + batch_indices, + positions, paged_k_cache, paged_v_cache, kv_indices, @@ -80,7 +84,8 @@ def _append_paged_kv_cache_kernel( def _fake_append_paged_kv_cache_kernel( append_key: torch.Tensor, append_value: torch.Tensor, - append_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, paged_k_cache: Optional[torch.Tensor], paged_v_cache: Optional[torch.Tensor], kv_indices: torch.Tensor, @@ -91,10 +96,107 @@ def _fake_append_paged_kv_cache_kernel( pass +@triton.jit +def get_batch_indices_positions_kernel( + append_indptr, + seq_lens_ptr, + batch_indices_ptr, + positions_ptr, + num_stages: tl.constexpr, +): + batch_idx = tl.program_id(0) + + batch_start = tl.load(append_indptr + batch_idx) + batch_end = tl.load(append_indptr + batch_idx + 1) + seq_len = tl.load(seq_lens_ptr + batch_idx) + + for i in tl.range(batch_start, batch_end, 128, num_stages=num_stages): + offsets = tl.arange(0, 128) + i + mask = offsets < batch_end + tl.store(batch_indices_ptr + offsets, batch_idx, mask) + tl.store(positions_ptr + offsets, offsets + seq_len - batch_end, mask) + + +def get_batch_indices_positions( + append_indptr: torch.Tensor, seq_lens: torch.Tensor, nnz: int +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Convert append indptr and sequence lengths to batch indices and positions. + + Parameters + ---------- + append_indptr : torch.Tensor + The indptr of the ragged tensor, shape: ``[batch_size + 1]``. + seq_lens: torch.Tensor + The sequence lengths of each request in the KV-Cache, shape: ``[batch_size]``. + nnz : int + The number of entries in the ragged tensor. 
+
+    Returns
+    -------
+    batch_indices: torch.Tensor
+        The batch indices of each entry in the ragged tensor, shape: ``[nnz]``.
+    positions: torch.Tensor
+        The positions of each entry in the ragged tensor, shape: ``[nnz]``.
+
+    Example
+    -------
+    >>> import torch
+    >>> import flashinfer
+    >>> nnz_kv = 10
+    >>> append_indptr = torch.tensor([0, 1, 3, 6, 10], dtype=torch.int32, device="cuda:0")
+    >>> seq_lens = torch.tensor([5, 5, 5, 5], device="cuda:0")
+    >>> batch_indices, positions = flashinfer.get_batch_indices_positions(append_indptr, seq_lens, nnz_kv)
+    >>> batch_indices
+    tensor([0, 1, 1, 2, 2, 2, 3, 3, 3, 3], device='cuda:0', dtype=torch.int32)
+    >>> positions # the rightmost column index of each row
+    tensor([4, 3, 4, 2, 3, 4, 1, 2, 3, 4], device='cuda:0', dtype=torch.int32)
+
+    Notes
+    -----
+    This function is similar to `CSR2COO `_ conversion in the cuSPARSE library,
+    with the difference that we are converting from a ragged tensor (which doesn't require
+    a column indices array) to a COO format.
+    """
+    batch_size = append_indptr.size(0) - 1
+    batch_indices = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32)
+    positions = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32)
+    get_batch_indices_positions_kernel[(batch_size,)](
+        append_indptr, seq_lens, batch_indices, positions, num_stages=2
+    )
+    return batch_indices, positions
+
+
+def get_seq_lens(
+    kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, page_size: int
+) -> torch.Tensor:
+    r"""Convert KV indptr and last page length to sequence lengths.
+
+    Parameters
+    ----------
+    kv_indptr : torch.Tensor
+        The indptr of the paged kv-cache, shape: ``[batch_size + 1]``.
+    kv_last_page_len : torch.Tensor
+        The number of entries in the last page of each request in the paged kv-cache,
+        shape: ``[batch_size]``.
+    page_size : int
+        The size of a page in the paged kv-cache.
+
+    Returns
+    -------
+    seq_lens: torch.Tensor
+        The sequence lengths of each request in the paged kv-cache, shape: ``[batch_size]``.
+    """
+    return (
+        torch.clamp(kv_indptr[1:] - kv_indptr[:-1] - 1, min=0) * page_size
+        + kv_last_page_len
+    )
+
+
 def append_paged_kv_cache(
     append_key: torch.Tensor,
     append_value: torch.Tensor,
-    append_indptr: torch.Tensor,
+    batch_indices: torch.Tensor,
+    positions: torch.Tensor,
     paged_kv_cache: torch.Tensor,
     kv_indices: torch.Tensor,
     kv_indptr: torch.Tensor,
@@ -111,8 +213,10 @@ def append_paged_kv_cache(
     append_value : torch.Tensor
         The value tensor to append in ragged tensor format, shape:
        ``[append_indptr[-1], num_kv_heads, head_dim]``.
-    append_indptr : torch.Tensor
-        The indptr tensor of the key-value pairs to append, shape: ``[batch_size + 1]``.
+    batch_indices : torch.Tensor
+        The batch indices of each entry in the appended key-value pairs, shape: ``[append_indptr[-1]]``.
+    positions : torch.Tensor
+        The positions of each entry in the appended key-value pairs, shape: ``[append_indptr[-1]]``.
     paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
         The paged KV-Cache stored as a tuple of tensors or a single tensor:
@@ -165,11 +269,28 @@ def append_paged_kv_cache(
     >>> # 25 = (2 - 1) * 16 + 9
     >>> # 22 = (2 - 1) * 16 + 6
     >>> kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0")
-    >>>
+    >>> batch_indices, positions = flashinfer.get_batch_indices_positions(
+    ...     kv_append_indptr, flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size), nnz_kv
+    ...
) + >>> batch_indices + tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, + 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3], device='cuda:0', dtype=torch.int32) + >>> positions + tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 0, 1, 2, 3, 4, 5, 6, 7, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], device='cuda:0', + dtype=torch.int32) >>> flashinfer.append_paged_kv_cache( ... k_append, ... v_append, - ... kv_append_indptr, + ... batch_indices, + ... positions, ... paged_kv_cache, ... kv_page_indices, ... kv_page_indptr, @@ -189,7 +310,8 @@ def append_paged_kv_cache( _append_paged_kv_cache_kernel( append_key, append_value, - append_indptr, + batch_indices, + positions, *_unpack_paged_kv_cache(paged_kv_cache, kv_layout), kv_indices, kv_indptr, diff --git a/src/test_page.cu b/src/test_page.cu index b36b143d..39cdad22 100644 --- a/src/test_page.cu +++ b/src/test_page.cu @@ -32,9 +32,11 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si size_t max_prefill_len = 128; size_t max_num_pages = num_conv_rounds * batch_size * ((max_decode_len + max_prefill_len) / page_size + 1); - std::vector kv_data_cpu(2 * max_num_pages * page_size * num_heads * head_dim); - utils::vec_zero_(kv_data_cpu); - thrust::device_vector kv_data_gpu(kv_data_cpu); + std::vector k_data_cpu(max_num_pages * page_size * num_heads * head_dim); + std::vector v_data_cpu(max_num_pages * page_size * num_heads * head_dim); + utils::vec_zero_(k_data_cpu); + utils::vec_zero_(v_data_cpu); + thrust::device_vector k_data_gpu(k_data_cpu), v_data_gpu(v_data_cpu); std::vector seq_len(batch_size); utils::vec_fill_(seq_len, 0); std::vector> page_indices(batch_size); @@ -45,6 +47,8 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si for (size_t round = 0; round < 2 * num_conv_rounds; ++round) { std::vector append_len(batch_size); std::vector append_indptr{0}; + std::vector batch_indices; + std::vector positions; std::vector> keys; std::vector> values; if (round % 2 == 0) { @@ -62,6 +66,8 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si } else { last_page_len[i] += 1; } + batch_indices.push_back(i); + positions.push_back(seq_len[i] - append_len[i] + j); } std::vector ki(append_len[i] * num_heads * head_dim), vi(append_len[i] * num_heads * head_dim); @@ -79,24 +85,24 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si } indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size()); } - paged_kv_t paged_kv_cpu( - num_heads, page_size, head_dim, batch_size, kv_layout, - /*k_data=*/kv_data_cpu.data(), - /*v_data=*/kv_data_cpu.data() + page_size * num_heads * head_dim, indices_cpu.data(), - indptr_cpu.data(), last_page_len.data()); + paged_kv_t paged_kv_cpu(num_heads, page_size, head_dim, batch_size, kv_layout, + /*k_data=*/k_data_cpu.data(), + /*v_data=*/v_data_cpu.data(), indices_cpu.data(), + indptr_cpu.data(), last_page_len.data()); cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); thrust::device_vector 
indptr_gpu(indptr_cpu); thrust::device_vector indices_gpu(indices_cpu); thrust::device_vector last_page_len_gpu(last_page_len); - paged_kv_t paged_kv_gpu( - num_heads, page_size, head_dim, batch_size, kv_layout, - /*k_data=*/thrust::raw_pointer_cast(kv_data_gpu.data()), - /*v_data=*/thrust::raw_pointer_cast(kv_data_gpu.data()) + page_size * num_heads * head_dim, - thrust::raw_pointer_cast(indices_gpu.data()), thrust::raw_pointer_cast(indptr_gpu.data()), - thrust::raw_pointer_cast(last_page_len_gpu.data())); - - thrust::device_vector append_indptr_gpu(append_indptr); + paged_kv_t paged_kv_gpu(num_heads, page_size, head_dim, batch_size, kv_layout, + /*k_data=*/thrust::raw_pointer_cast(k_data_gpu.data()), + /*v_data=*/thrust::raw_pointer_cast(v_data_gpu.data()), + thrust::raw_pointer_cast(indices_gpu.data()), + thrust::raw_pointer_cast(indptr_gpu.data()), + thrust::raw_pointer_cast(last_page_len_gpu.data())); + + thrust::device_vector batch_indices_gpu(batch_indices); + thrust::device_vector positions_gpu(positions); thrust::device_vector keys_gpu(append_indptr.back() * num_heads * head_dim); thrust::device_vector values_gpu(append_indptr.back() * num_heads * head_dim); for (size_t i = 0; i < batch_size; ++i) { @@ -107,12 +113,19 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si thrust::copy(vi.begin(), vi.end(), values_gpu.begin() + append_indptr[i] * num_heads * head_dim); } + if (round % 2 == 0) { // call prefill kernel cudaError_t status = AppendPagedKVCache(paged_kv_gpu, thrust::raw_pointer_cast(keys_gpu.data()), thrust::raw_pointer_cast(values_gpu.data()), - thrust::raw_pointer_cast(append_indptr_gpu.data())); + thrust::raw_pointer_cast(batch_indices_gpu.data()), + thrust::raw_pointer_cast(positions_gpu.data()), + /*nnz=*/append_indptr.back(), + /*append_k_stride_n=*/num_heads * head_dim, + /*append_k_stride_h=*/head_dim, + /*append_v_stride_n=*/num_heads * head_dim, + /*append_v_stride_h=*/head_dim); EXPECT_EQ(status, cudaSuccess) << "AppendPagedKVCache kernel launch failed, error message: " << cudaGetErrorString(status); } else { @@ -126,18 +139,25 @@ void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, si } } - thrust::host_vector kv_data_gpu_h(kv_data_gpu); + thrust::host_vector k_data_gpu_h(k_data_gpu), v_data_gpu_h(v_data_gpu); size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; - for (size_t i = 0; i < kv_data_cpu.size(); ++i) { - if (std::isnan(float(kv_data_gpu_h[i]))) { + for (size_t i = 0; i < k_data_cpu.size(); ++i) { + if (std::isnan(float(k_data_gpu_h[i]))) { + nan_detected = true; + } + num_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(float(k_data_cpu[i]), float(k_data_gpu_h[i]), 1e-3, 1e-3)); + } + for (size_t i = 0; i < v_data_cpu.size(); ++i) { + if (std::isnan(float(v_data_gpu_h[i]))) { nan_detected = true; } num_result_errors_atol_1e_3_rtol_1e_3 += - (!utils::isclose(float(kv_data_cpu[i]), float(kv_data_gpu_h[i]), 1e-3, 1e-3)); + (!utils::isclose(float(v_data_cpu[i]), float(v_data_gpu_h[i]), 1e-3, 1e-3)); } - float result_accuracy = - 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) / float(kv_data_cpu.size()); + float result_accuracy = 1. 
- float(num_result_errors_atol_1e_3_rtol_1e_3) / + float(k_data_cpu.size() + v_data_cpu.size()); std::cout << "kv_layout=" << QKVLayoutToString(kv_layout) << ", page_size=" << page_size << ", batch_size=" << batch_size << ", num_heads=" << num_heads << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl; diff --git a/tests/test_page.py b/tests/test_page.py index 97daa8e6..c075725b 100644 --- a/tests/test_page.py +++ b/tests/test_page.py @@ -1,19 +1,28 @@ +import pytest import torch import flashinfer -def test_append_paged_kv_cache(): +@pytest.mark.parametrize("contiguous", [True, False]) +def test_append_paged_kv_cache(contiguous): nnz_kv = 100 num_kv_heads = 32 head_dim = 128 - k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) - v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) + + if contiguous: + k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) + v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0) + else: + kv_append = torch.randn(nnz_kv, 2, num_kv_heads, head_dim).half().to(0) + k_append = kv_append[:, 0] + v_append = kv_append[:, 1] # 45 + 8 + 25 + 22 = nnz_kv kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0") kv_append_indptr = torch.cat( [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)] ).int() + max_num_pages = 1000 page_size = 16 paged_kv_cache = ( @@ -30,11 +39,17 @@ def test_append_paged_kv_cache(): # 25 = (2 - 1) * 16 + 9 # 22 = (2 - 1) * 16 + 6 kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0") + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, + flashinfer.get_seq_lens(kv_append_indptr, kv_last_page_len, page_size), + nnz_kv, + ) - flashinfer.page.append_paged_kv_cache( + flashinfer.append_paged_kv_cache( k_append, v_append, - kv_append_indptr, + batch_indices, + positions, paged_kv_cache, kv_page_indices, kv_page_indptr,
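
The Triton kernel added in python/flashinfer/page.py is the fast path, but the CSR-to-COO conversion it performs is easy to restate in eager PyTorch. The sketch below is illustrative only and not part of this change (the helper name reference_batch_indices_positions is ours); it assumes append_indptr and seq_lens are integer tensors on the same CUDA device, as in the get_batch_indices_positions docstring example, and can serve as a cross-check for the kernel in a unit test.

import torch

def reference_batch_indices_positions(append_indptr, seq_lens, nnz):
    # Number of entries appended for each request (ragged row lengths).
    device = append_indptr.device
    append_lens = (append_indptr[1:] - append_indptr[:-1]).long()
    batch_size = append_lens.numel()
    # Batch index of every appended entry: request id repeated append_lens[i] times.
    batch_indices = torch.repeat_interleave(
        torch.arange(batch_size, device=device, dtype=torch.int32), append_lens
    )
    # Index of each entry within its own request: 0, 1, ..., append_len - 1.
    local_idx = torch.arange(nnz, device=device, dtype=torch.int32) - torch.repeat_interleave(
        append_indptr[:-1], append_lens
    )
    # Appended entries occupy the tail of each sequence, so the last one lands at seq_len - 1.
    positions = local_idx + (seq_lens.to(device) - append_lens.int())[batch_indices.long()]
    return batch_indices, positions.int()

# With append_indptr = [0, 1, 3, 6, 10], seq_lens = [5, 5, 5, 5], nnz = 10 (the docstring example),
# this yields batch_indices = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3] and positions = [4, 3, 4, 2, 3, 4, 1, 2, 3, 4].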