From ac72b1cc14a6474d601f371c8d69e2600ac28d2f Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 10 Jul 2024 00:01:39 -0700 Subject: [PATCH] bugfix: fix decode kernels output for empty kv cache (#363) When some request has empty kv cache, the output of decode kernels doesn't align with prefill kernels. This PR fixes the issue. Thanks @MasterJH5574 for reporting this bug. --- include/flashinfer/attention/handler.cuh | 2 +- python/tests/test_decode_prefill_lse.py | 79 ++++++++++++++++++++++++ python/tests/test_tensor_cores_decode.py | 8 +-- 3 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 python/tests/test_decode_prefill_lse.py diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 2a3d4495..632ccb85 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -238,7 +238,7 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( for (uint32_t batch_idx = 0; batch_idx < old_batch_size; batch_idx++) { uint32_t num_chunks = ceil_div(old_indptr_h[batch_idx + 1] - old_indptr_h[batch_idx], max_num_pages_per_batch); - chunk_indptr_vec.push_back(chunk_indptr_vec.back() + num_chunks); + chunk_indptr_vec.push_back(chunk_indptr_vec.back() + std::max(num_chunks, 1U)); if (num_chunks == 0) { new_page_indptr_vec.push_back(old_indptr_h[batch_idx]); new_last_page_len_vec.push_back(0); diff --git a/python/tests/test_decode_prefill_lse.py b/python/tests/test_decode_prefill_lse.py new file mode 100644 index 00000000..f79bdac8 --- /dev/null +++ b/python/tests/test_decode_prefill_lse.py @@ -0,0 +1,79 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import flashinfer +import numpy as np +import torch +import pytest + + +def test_mlc_failed_case(): + kv_layout = "HND" + num_pages = 12 + kv_indptr_1 = torch.tensor([0, 0, 9]).int().to(0) + kv_indices_1 = torch.tensor([3, 4, 5, 6, 7, 8, 9, 10, 11]).int().to(0) + kv_last_page_len_1 = torch.tensor([0, 1]).int().to(0) + num_qo_heads = 32 + num_kv_heads = 32 + page_size = 16 + head_dim = 128 + q = torch.randn(2, num_qo_heads, head_dim).to(0).half() + kv_data = torch.randn(12, 2, num_kv_heads, page_size, head_dim).to(0).half() + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) + wrapper.begin_forward( + kv_indptr_1, + kv_indices_1, + kv_last_page_len_1, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode="NONE", + data_type=torch.float16, + q_data_type=torch.float16, + ) + o_1, lse_1 = wrapper.forward_return_lse(q, kv_data) + + wrapper_tensor_cores = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, use_tensor_cores=True + ) + wrapper_tensor_cores.begin_forward( + kv_indptr_1, + kv_indices_1, + kv_last_page_len_1, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + pos_encoding_mode="NONE", + data_type=torch.float16, + q_data_type=torch.float16, + ) + o_1_tc, lse_1_tc = wrapper_tensor_cores.forward_return_lse( + q, kv_data + ) + + np.testing.assert_allclose( + lse_1.cpu().numpy(), lse_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + np.testing.assert_allclose( + o_1.cpu().numpy(), o_1_tc.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + +if __name__ == "__main__": + test_mlc_failed_case() diff --git a/python/tests/test_tensor_cores_decode.py b/python/tests/test_tensor_cores_decode.py index 4cac6be7..b49c522d 100644 --- a/python/tests/test_tensor_cores_decode.py +++ b/python/tests/test_tensor_cores_decode.py @@ -104,7 +104,7 @@ def test_batch_decode_tensor_cores( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, ) @@ -121,7 +121,7 @@ def test_batch_decode_tensor_cores( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, ) @@ -187,7 +187,7 @@ def test_batch_decode_tensor_cores_cuda_graph( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, ) @@ -226,7 +226,7 @@ def test_batch_decode_tensor_cores_cuda_graph( num_kv_heads, head_dim, page_size, - "NONE", + pos_encoding_mode=pos_encoding_mode, data_type=torch.float16, q_data_type=torch.float16, )