From 7adc8cf01a029645307c321a7754d0b0a4f0f4de Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Wed, 3 Jul 2024 00:42:37 -0700
Subject: [PATCH] bugfix: fix prefill/append kernel behavior for empty
 kv-cache. (#353)

The prefill kernels were buggy when some of the requests had an empty
kv-cache; this PR fixes the issue.
---
 include/flashinfer/attention/handler.cuh |   6 +-
 include/flashinfer/attention/prefill.cuh |  22 +++-
 src/test_batch_prefill.cu                | 158 ++++++++++++++++++++++-
 3 files changed, 175 insertions(+), 11 deletions(-)
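Note (illustration, not part of the patch; git ignores text in this region
when applying): the root cause is unsigned integer underflow when a request
has an empty kv-cache. With zero pages, paged_kv.indptr[i + 1] equals
paged_kv.indptr[i], so the old kv_len expression computes (0 - 1) in uint32_t
and wraps around; the handler additionally clamps kv_len to at least 1 when
sizing KV tiles. A minimal standalone sketch with simplified stand-in names
for the kernel's fields:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t page_size = 16, last_page_len = 0;
      // Empty kv-cache: the request owns no pages, so consecutive indptr
      // entries are equal.
      const uint32_t indptr_begin = 7, indptr_end = 7;

      // Old computation: (0 - 1) wraps in uint32_t, producing a huge kv_len.
      const uint32_t kv_len_buggy =
          (indptr_end - indptr_begin - 1) * page_size + last_page_len;

      // Patched computation: guard the empty case explicitly.
      const uint32_t kv_len_fixed =
          (indptr_end != indptr_begin)
              ? (indptr_end - indptr_begin - 1) * page_size + last_page_len
              : 0;

      // Prints "buggy=4294967280 fixed=0".
      printf("buggy=%u fixed=%u\n", kv_len_buggy, kv_len_fixed);
      return 0;
    }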
diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index 142fa668..af17de89 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -114,7 +114,8 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
 
   new_batch_size = 0;
   for (uint32_t i = 0; i < batch_size; ++i) {
-    new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(kv_len_arr[i], low);
+    new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) *
+                      ceil_div(std::max(int(kv_len_arr[i]), 1), low);
   }
   return {low < max_kv_len, low, new_batch_size};
 }
@@ -571,7 +572,8 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
   // step 3: split qo_indptr and kv_indptr
   total_num_tiles_q = 0;
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    int64_t packed_qo_len = packed_qo_len_arr[request_idx], kv_len = kv_len_arr[request_idx];
+    int64_t packed_qo_len = packed_qo_len_arr[request_idx],
+            kv_len = std::max(int(kv_len_arr[request_idx]), 1);
     int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size),
             num_tiles_kv = ceil_div(kv_len, kv_chunk_size);
     total_num_tiles_q += num_tiles_q;
diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index cb8f71b5..97fbc4ff 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -619,7 +619,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
                       reg_id % 2;
     const bool out_of_boundary =
         (mask_mode == MaskMode::kCausal
-             ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end))
+             ? (kv_idx + qo_len > kv_len + q_idx || (partition_kv && kv_idx >= chunk_end))
             : kv_idx >= chunk_end);
     s_frag[fx][fz][reg_id] =
         (out_of_boundary ||
@@ -1503,9 +1503,11 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     kv_tile_idx = kv_tile_indices[bx];
   constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16;
   const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx],
-                 kv_len = (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) *
-                              paged_kv.page_size +
-                          paged_kv.last_page_len[request_idx];
+                 kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx])
+                              ? (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] -
+                                 1) * paged_kv.page_size +
+                                paged_kv.last_page_len[request_idx]
+                              : 0;
   const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len;
   const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0;
   const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len;
@@ -1908,6 +1910,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
   const uint32_t group_size = num_qo_heads / num_kv_heads;
   const uint_fastdiv group_size_fastdiv(group_size);
 
+  if (padded_batch_size == 0) {
+    // No request, skip
+    // this won't happen in CUDAGraph mode because we fixed the padded_batch_size
+    return cudaSuccess;
+  }
+
   dim3 nblks(padded_batch_size, 1, num_kv_heads);
   dim3 nthrs(32, num_warps_x, num_warps_z);
   constexpr uint32_t num_frags_y = HEAD_DIM / 16;
@@ -2040,6 +2048,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
   const uint32_t group_size = num_qo_heads / num_kv_heads;
   const uint_fastdiv group_size_fastdiv(group_size);
 
+  if (padded_batch_size == 0) {
+    // No request, skip
+    // this won't happen in CUDAGraph mode because we fixed the padded_batch_size
+    return cudaSuccess;
+  }
+
   dim3 nblks(padded_batch_size, 1, num_kv_heads);
   dim3 nthrs(32, num_warps_x, num_warps_z);
 
diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu
index a9e78468..976add0e 100644
--- a/src/test_batch_prefill.cu
+++ b/src/test_batch_prefill.cu
@@ -18,6 +18,7 @@
 #include <gtest/gtest.h>
 
 #include "cpu_reference.h"
+#include "flashinfer/pos_enc.cuh"
 #include "flashinfer_ops.cuh"
 #include "utils.h"
 
@@ -237,12 +238,13 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
   std::vector<int32_t> q_lens(batch_size);
   utils::vec_randint_(q_lens, 1, 64);
   std::vector<int32_t> kv_lens(q_lens);
+
   std::vector<int32_t> q_indptr{0};
-  for (uint32_t i = 0; i < batch_size; ++i) {
-    q_indptr.push_back(q_indptr.back() + q_lens[i]);
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    q_indptr.push_back(q_indptr.back() + q_lens[request_idx]);
   }
   std::vector<int32_t> append_indptr{0};
-  for (size_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
     append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]);
   }
   std::vector<T> kv_data;
@@ -295,7 +297,6 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
     q.push_back(qi);
   }
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    // create one-hot queries
     int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
     std::vector<T> o_ref_i = cpu_reference::single_mha<T, T>(
         q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
@@ -318,7 +319,7 @@
   thrust::device_vector<char> buffer(workspace_size_in_bytes);
 
   handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()),
-                       workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
+                       workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
                        batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
 
   auto status =
@@ -350,6 +351,128 @@
   EXPECT_EQ(nan_detected, false) << "NaN detected in output.";
 }
 
+template <typename T>
+void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
+    size_t batch_size, size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim,
+    bool allow_fp16_qk_reduction, uint32_t q_len_min, uint32_t q_len_max, uint32_t kv_len_min,
+    uint32_t kv_len_max) {
+  std::vector<int32_t> q_lens(batch_size);
+  utils::vec_randint_(q_lens, q_len_min, q_len_max);
+  std::vector<int32_t> kv_lens(batch_size);
+  utils::vec_randint_(kv_lens, kv_len_min, kv_len_max);
+
+  std::vector<int32_t> q_indptr{0};
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    q_indptr.push_back(q_indptr.back() + q_lens[request_idx]);
+  }
+  std::vector<int32_t> append_indptr{0};
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]);
+  }
+  std::vector<T> kv_data;
+  std::vector<int32_t> kv_indptr{0};
+  std::vector<int32_t> kv_indices;
+  std::vector<int32_t> kv_last_page_len;
+  size_t page_counter = 0;
+  std::vector<std::vector<T>> key, value;
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    size_t kv_len = kv_lens[request_idx];
+    size_t num_pages = (kv_len + page_size - 1) / page_size;
+    size_t last_page_len = num_pages == 0 ? 0 : (kv_len - 1) % page_size + 1;
+    std::vector<T> k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim);
+    utils::vec_normal_(k);
+    utils::vec_normal_(v);
+    key.push_back(k);
+    value.push_back(v);
+    kv_last_page_len.push_back(last_page_len);
+    kv_indptr.push_back(kv_indptr.back() + num_pages);
+    for (size_t j = 0; j < num_pages; ++j) {
+      kv_indices.push_back(page_counter++);
+    }
+  }
+
+  kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim);
+  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
+      num_kv_heads, page_size, head_dim, batch_size, kv_data.data(), kv_indices.data(),
+      kv_indptr.data(), kv_last_page_len.data());
+  cpu_reference::append_paged_kv_cache(paged_kv_cpu, key, value,
+                                       append_indptr);
+
+  // copy data to device
+  thrust::device_vector<T> kv_data_device(kv_data);
+  thrust::device_vector<int32_t> kv_indptr_device(kv_indptr);
+  thrust::device_vector<int32_t> kv_indices_device(kv_indices);
+  thrust::device_vector<int32_t> kv_last_page_len_device(kv_last_page_len);
+
+  // create paged_kv object
+  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv = paged_kv_cpu;
+  paged_kv.data = thrust::raw_pointer_cast(kv_data_device.data());
+  paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data());
+  paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data());
+  paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data());
+
+  std::vector<std::vector<T>> q, o_ref;
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    int32_t q_len = q_lens[request_idx];
+    std::vector<T> qi(q_len * num_qo_heads * head_dim);
+    utils::vec_normal_(qi);
+    q.push_back(qi);
+  }
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
+    std::vector<T> o_ref_i = cpu_reference::single_mha<T, T>(
+        q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
+        num_kv_heads, head_dim, /*causal=*/false, QKVLayout::kNHD,
+        /*pos_encoding_mode*/ PosEncodingMode::kNone);
+    o_ref.push_back(o_ref_i);
+  }
+
+  std::vector<T> q_concat, o_concat_ref;
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    q_concat.insert(q_concat.end(), q[request_idx].begin(), q[request_idx].end());
+    o_concat_ref.insert(o_concat_ref.end(), o_ref[request_idx].begin(), o_ref[request_idx].end());
+  }
+  thrust::device_vector<T> q_device(q_concat);
+
+  thrust::device_vector<int32_t> q_indptr_device(q_indptr);
+  thrust::device_vector<T> o_device(o_concat_ref.size());
+
+  BatchPrefillHandler handler;
+  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector<char> buffer(workspace_size_in_bytes);
+
+  handler.BeginForward((void*)thrust::raw_pointer_cast(buffer.data()),
+                       workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
+                       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+
+  auto status =
+      BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(
+          &handler, thrust::raw_pointer_cast(q_device.data()),
+          thrust::raw_pointer_cast(q_indptr_device.data()),
+          /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()),
+          /*lse=*/nullptr, num_qo_heads, /*causal=*/false,
+          /*pos_encoding_mode*/ PosEncodingMode::kNone);
+  EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status));
+
+  thrust::host_vector<T> o_host(o_device);
+  size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0;
+  bool nan_detected = false;
+  for (size_t i = 0; i < o_concat_ref.size(); ++i) {
+    if (std::isnan(float(o_host[i]))) {
+      nan_detected = true;
+    }
+    num_result_errors_atol_1e_3_rtol_1e_3 +=
+        (!utils::isclose(float(o_host[i]), float(o_concat_ref[i]), 1e-3, 1e-3));
+  }
+  float result_accuracy =
+      1. - float(num_result_errors_atol_1e_3_rtol_1e_3) / max(float(o_concat_ref.size()), 1.f);
+  std::cout << "batch_size=" << batch_size << ", page_size=" << page_size
+            << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads
+            << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl;
+  EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
+  EXPECT_EQ(nan_detected, false) << "NaN detected in output.";
+}
+
 template <typename T>
 void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t num_qo_heads,
                                                         size_t page_size, size_t head_dim,
@@ -505,6 +628,27 @@ void TestBatchPagedPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduct
   }
 }
 
+template <typename T>
+void TestBatchPagedPrefillKernelZeroContextCorrectness(bool allow_fp16_qk_reduction) {
+  for (size_t batch_size : {1, 4, 7, 11, 19, 37, 99}) {
+    for (size_t num_kv_heads : {1, 4}) {
+      for (size_t group_size : {1, 8}) {
+        size_t num_qo_heads = num_kv_heads * group_size;
+        for (size_t page_size : {1, 16}) {
+          for (size_t head_dim : {64, 128, 256}) {
+            for (size_t kv_len_max : {0, 3}) {
+              _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness<T>(
+                  batch_size, num_kv_heads, num_qo_heads, page_size, head_dim,
+                  allow_fp16_qk_reduction,
+                  /*q_len_min=*/1, /*q_len_max=*/3, /*kv_len_min=*/0, kv_len_max);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 template <typename T>
 void TestBatchRaggedPrefillKernelCorrectness(bool allow_fp16_qk_reduction) {
   for (size_t num_kv_heads : {4, 8, 32}) {
@@ -534,6 +678,10 @@ TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16) {
   TestBatchPagedPrefillKernelLongContextCorrectness<half>(false);
 }
 
+TEST(FlashInferCorrectnessTest, BatchPagedPrefillZeroContextTestFP16) {
+  TestBatchPagedPrefillKernelZeroContextCorrectness<half>(false);
+}
+
 TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16QKHalfAccum) {
   TestBatchPagedPrefillKernelLongContextCorrectness<half>(true);
 }
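Note (illustration, not part of the patch): the mask_s change in prefill.cuh
rewrites the causal boundary check from kv_idx > kv_len + q_idx - qo_len to
kv_idx + qo_len > kv_len + q_idx. The two are algebraically identical, but the
operands are uint32_t, and the old right-hand side underflows whenever
kv_len + q_idx < qo_len (for example, an empty kv-cache), so out-of-range
positions were never masked. A standalone sketch with simplified names:

    #include <cstdint>
    #include <cstdio>

    // Causal-mask predicate before and after the patch; all operands are
    // uint32_t, as in the kernel.
    static bool out_of_boundary_old(uint32_t kv_idx, uint32_t q_idx, uint32_t qo_len,
                                    uint32_t kv_len) {
      // kv_len + q_idx - qo_len underflows when kv_len + q_idx < qo_len.
      return kv_idx > kv_len + q_idx - qo_len;
    }

    static bool out_of_boundary_new(uint32_t kv_idx, uint32_t q_idx, uint32_t qo_len,
                                    uint32_t kv_len) {
      // Same inequality with the subtraction moved to the other side.
      return kv_idx + qo_len > kv_len + q_idx;
    }

    int main() {
      // First query row of a 5-token request with an empty kv-cache: every
      // kv position must be masked out.
      const uint32_t kv_idx = 0, q_idx = 0, qo_len = 5, kv_len = 0;
      // Old: 0 > 4294967291 is false, so garbage KV entries were attended to.
      // New: 5 > 0 is true. Prints "old=0 new=1".
      printf("old=%d new=%d\n", (int)out_of_boundary_old(kv_idx, q_idx, qo_len, kv_len),
             (int)out_of_boundary_new(kv_idx, q_idx, qo_len, kv_len));
      return 0;
    }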