Skip to content

Commit

Permalink
bugfix: fix prefill/append kernel behavior for empty kv-cache. (#353)
Browse files Browse the repository at this point in the history
The prefill kernels were buggy when some of the requests had an empty
kv-cache; this PR fixes the issue.
  • Loading branch information
yzh119 authored Jul 3, 2024
1 parent d1d443a commit 7adc8cf
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 11 deletions.
6 changes: 4 additions & 2 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(

new_batch_size = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(kv_len_arr[i], low);
new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) *
ceil_div(std::max(int(kv_len_arr[i]), 1), low);
}
return {low < max_kv_len, low, new_batch_size};
}
Expand Down Expand Up @@ -571,7 +572,8 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
// step 3: split qo_indptr and kv_indptr
total_num_tiles_q = 0;
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
int64_t packed_qo_len = packed_qo_len_arr[request_idx], kv_len = kv_len_arr[request_idx];
int64_t packed_qo_len = packed_qo_len_arr[request_idx],
kv_len = std::max(int(kv_len_arr[request_idx]), 1);
int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size),
num_tiles_kv = ceil_div(kv_len, kv_chunk_size);
total_num_tiles_q += num_tiles_q;
Expand Down
22 changes: 18 additions & 4 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
reg_id % 2;
const bool out_of_boundary =
(mask_mode == MaskMode::kCausal
? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end))
? (kv_idx + qo_len > kv_len + q_idx || (partition_kv && kv_idx >= chunk_end))
: kv_idx >= chunk_end);
s_frag[fx][fz][reg_id] =
(out_of_boundary ||
Expand Down Expand Up @@ -1503,9 +1503,11 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
kv_tile_idx = kv_tile_indices[bx];
constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16;
const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx],
kv_len = (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) *
paged_kv.page_size +
paged_kv.last_page_len[request_idx];
kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx])
? (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] -
1) * paged_kv.page_size +
paged_kv.last_page_len[request_idx]
: 0;
const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len;
const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0;
const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len;
Expand Down Expand Up @@ -1908,6 +1910,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
const uint32_t group_size = num_qo_heads / num_kv_heads;
const uint_fastdiv group_size_fastdiv(group_size);

if (padded_batch_size == 0) {
// No request, skip
// this won't happen in CUDAGraph mode because we fixed the padded_batch_size
return cudaSuccess;
}

dim3 nblks(padded_batch_size, 1, num_kv_heads);
dim3 nthrs(32, num_warps_x, num_warps_z);
constexpr uint32_t num_frags_y = HEAD_DIM / 16;
Expand Down Expand Up @@ -2040,6 +2048,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
const uint32_t group_size = num_qo_heads / num_kv_heads;
const uint_fastdiv group_size_fastdiv(group_size);

if (padded_batch_size == 0) {
// No request, skip
// this won't happen in CUDAGraph mode because we fixed the padded_batch_size
return cudaSuccess;
}

dim3 nblks(padded_batch_size, 1, num_kv_heads);
dim3 nthrs(32, num_warps_x, num_warps_z);

Expand Down
158 changes: 153 additions & 5 deletions src/test_batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cstdint>

#include "cpu_reference.h"
#include "flashinfer/pos_enc.cuh"
#include "flashinfer_ops.cuh"
#include "utils.h"

Expand Down Expand Up @@ -237,12 +238,13 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
std::vector<int32_t> q_lens(batch_size);
utils::vec_randint_(q_lens, 1, 64);
std::vector<int32_t> kv_lens(q_lens);

std::vector<int32_t> q_indptr{0};
for (uint32_t i = 0; i < batch_size; ++i) {
q_indptr.push_back(q_indptr.back() + q_lens[i]);
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
q_indptr.push_back(q_indptr.back() + q_lens[request_idx]);
}
std::vector<int32_t> append_indptr{0};
for (size_t request_idx = 0; request_idx < batch_size; ++request_idx) {
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]);
}
std::vector<T> kv_data;
Expand Down Expand Up @@ -295,7 +297,6 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
q.push_back(qi);
}
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
// create one-hot queries
int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
std::vector<T> o_ref_i = cpu_reference::single_mha<T, T, T>(
q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
Expand All @@ -318,7 +319,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
thrust::device_vector<char> buffer(workspace_size_in_bytes);

handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

auto status =
Expand Down Expand Up @@ -350,6 +351,128 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
EXPECT_EQ(nan_detected, false) << "NaN detected in output.";
}

/*!
 * \brief Correctness check for the batched paged-KV prefill path with randomly drawn
 *   per-request query and KV lengths.
 *
 * The per-request lengths are drawn by utils::vec_randint_ from the caller-supplied
 * (q_len_min, q_len_max) and (kv_len_min, kv_len_max) bounds. A request whose KV length is 0
 * gets zero pages and a last-page length of 0, which is the empty kv-cache case this test
 * targets. The device result of BatchPrefillWithPagedKVCacheWrapper (non-causal, no positional
 * encoding) is compared element-wise with cpu_reference::single_mha using atol = rtol = 1e-3;
 * the test expects more than 99% of the elements to match and no NaN in the output.
 *
 * \tparam T Element type used for q, k, v and the output.
 * \param batch_size Number of requests in the batch.
 * \param num_kv_heads Number of KV heads.
 * \param num_qo_heads Number of query/output heads.
 * \param page_size Number of KV entries per page.
 * \param head_dim Per-head feature dimension.
 * \param allow_fp16_qk_reduction Accepted for signature parity; this body does not forward it
 *   to the kernel wrapper.
 * \param q_len_min Lower bound passed to utils::vec_randint_ for query lengths.
 * \param q_len_max Upper bound passed to utils::vec_randint_ for query lengths.
 * \param kv_len_min Lower bound passed to utils::vec_randint_ for KV lengths.
 * \param kv_len_max Upper bound passed to utils::vec_randint_ for KV lengths.
 */
template <typename T>
void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
    size_t batch_size, size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim,
    bool allow_fp16_qk_reduction, uint32_t q_len_min, uint32_t q_len_max, uint32_t kv_len_min,
    uint32_t kv_len_max) {
  // ---- draw the per-request lengths (query lengths first, then KV lengths) ----
  std::vector<int32_t> q_lens(batch_size);
  utils::vec_randint_(q_lens, q_len_min, q_len_max);
  std::vector<int32_t> kv_lens(batch_size);
  utils::vec_randint_(kv_lens, kv_len_min, kv_len_max);

  // ---- running sums of the lengths: row offsets for packed q and for the appended KV ----
  std::vector<int32_t> q_indptr{0};
  std::vector<int32_t> append_indptr{0};
  for (uint32_t req = 0; req < batch_size; ++req) {
    q_indptr.push_back(q_indptr.back() + q_lens[req]);
    append_indptr.push_back(append_indptr.back() + kv_lens[req]);
  }

  // ---- random K/V per request plus the page-table metadata describing them ----
  std::vector<int32_t> kv_indptr{0};
  std::vector<int32_t> kv_indices;
  std::vector<int32_t> kv_last_page_len;
  std::vector<std::vector<T>> key, value;
  size_t next_page = 0;
  for (uint32_t req = 0; req < batch_size; ++req) {
    const size_t kv_len = kv_lens[req];
    const size_t pages_needed = (kv_len + page_size - 1) / page_size;
    // An empty request owns no page, so its last-page length is recorded as 0.
    size_t tail_len = 0;
    if (pages_needed != 0) {
      tail_len = (kv_len - 1) % page_size + 1;
    }

    const size_t kv_elems = kv_len * num_kv_heads * head_dim;
    std::vector<T> k_req(kv_elems);
    std::vector<T> v_req(kv_elems);
    utils::vec_normal_(k_req);
    utils::vec_normal_(v_req);
    key.push_back(k_req);
    value.push_back(v_req);

    kv_last_page_len.push_back(tail_len);
    kv_indptr.push_back(kv_indptr.back() + pages_needed);
    // Pages are handed out consecutively across the whole batch.
    for (size_t p = 0; p < pages_needed; ++p, ++next_page) {
      kv_indices.push_back(next_page);
    }
  }

  // ---- fill the host-side paged cache through the CPU reference append ----
  std::vector<T> kv_data(next_page * 2 * num_kv_heads * page_size * head_dim);
  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
      num_kv_heads, page_size, head_dim, batch_size, kv_data.data(), kv_indices.data(),
      kv_indptr.data(), kv_last_page_len.data());
  cpu_reference::append_paged_kv_cache<kv_layout, T, int32_t>(paged_kv_cpu, key, value,
                                                              append_indptr);

  // ---- mirror the cache on the device and retarget the descriptor's pointers ----
  thrust::device_vector<T> kv_data_device(kv_data);
  thrust::device_vector<int32_t> kv_indptr_device(kv_indptr);
  thrust::device_vector<int32_t> kv_indices_device(kv_indices);
  thrust::device_vector<int32_t> kv_last_page_len_device(kv_last_page_len);

  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv = paged_kv_cpu;
  paged_kv.data = thrust::raw_pointer_cast(kv_data_device.data());
  paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data());
  paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data());
  paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data());

  // ---- random queries for every request ----
  std::vector<std::vector<T>> q, o_ref;
  for (uint32_t req = 0; req < batch_size; ++req) {
    const int32_t q_len = q_lens[req];
    std::vector<T> q_req(q_len * num_qo_heads * head_dim);
    utils::vec_normal_(q_req);
    q.push_back(q_req);
  }

  // ---- CPU reference attention output for every request ----
  for (uint32_t req = 0; req < batch_size; ++req) {
    const int32_t q_len = q_lens[req];
    const int32_t kv_len = kv_lens[req];
    o_ref.push_back(cpu_reference::single_mha<T, T, T>(
        q[req], key[req], value[req], q_len, kv_len, num_qo_heads, num_kv_heads, head_dim,
        /*causal=*/false, QKVLayout::kNHD,
        /*pos_encoding_mode*/ PosEncodingMode::kNone));
  }

  // ---- pack the per-request tensors into flat buffers ----
  std::vector<T> q_concat, o_concat_ref;
  for (uint32_t req = 0; req < batch_size; ++req) {
    const std::vector<T>& q_req = q[req];
    const std::vector<T>& o_req = o_ref[req];
    q_concat.insert(q_concat.end(), q_req.begin(), q_req.end());
    o_concat_ref.insert(o_concat_ref.end(), o_req.begin(), o_req.end());
  }
  thrust::device_vector<T> q_device(q_concat);
  thrust::device_vector<int32_t> q_indptr_device(q_indptr);
  thrust::device_vector<T> o_device(o_concat_ref.size());

  // ---- plan and launch; BeginForward is given the host-side indptr arrays ----
  BatchPrefillHandler handler;
  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
  thrust::device_vector<char> buffer(workspace_size_in_bytes);
  void* workspace = static_cast<void*>(thrust::raw_pointer_cast(buffer.data()));

  handler.BeginForward<T, int32_t>(workspace, workspace_size_in_bytes, q_indptr.data(),
                                   kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads,
                                   head_dim, page_size);

  auto status =
      BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(
          &handler, thrust::raw_pointer_cast(q_device.data()),
          thrust::raw_pointer_cast(q_indptr_device.data()),
          /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()),
          /*lse=*/nullptr, num_qo_heads, /*causal=*/false,
          /*pos_encoding_mode*/ PosEncodingMode::kNone);
  EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status));

  // ---- compare the device output with the reference, element by element ----
  thrust::host_vector<T> o_host(o_device);
  size_t mismatch_count = 0;
  bool saw_nan = false;
  for (size_t i = 0; i < o_concat_ref.size(); ++i) {
    const float got = float(o_host[i]);
    const float want = float(o_concat_ref[i]);
    saw_nan = saw_nan || std::isnan(got);
    if (!utils::isclose(got, want, 1e-3, 1e-3)) {
      ++mismatch_count;
    }
  }
  // The denominator is clamped to 1 so that an empty output does not divide by zero.
  float result_accuracy = 1. - float(mismatch_count) / max(float(o_concat_ref.size()), 1.f);
  std::cout << "batch_size=" << batch_size << ", page_size=" << page_size
            << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads
            << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl;
  EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
  EXPECT_EQ(saw_nan, false) << "NaN detected in output.";
}

template <typename T>
void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t num_qo_heads,
size_t page_size, size_t head_dim,
Expand Down Expand Up @@ -505,6 +628,27 @@ void TestBatchPagedPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduct
}
}

/*!
 * \brief Sweeps the paged prefill correctness check over configurations in which requests may
 *   have an empty kv-cache.
 *
 * Every combination of batch size, KV-head count, group size, page size, head dimension and
 * KV-length upper bound listed below is run through
 * _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness with query-length bounds 1 and 3 and a
 * KV-length lower bound of 0. A KV-length upper bound of 0 gives every request an empty
 * kv-cache.
 *
 * \tparam T Element type forwarded to the underlying test.
 * \param allow_fp16_qk_reduction Forwarded unchanged to the underlying test.
 */
template <typename T>
void TestBatchPagedPrefillKernelZeroContextCorrectness(bool allow_fp16_qk_reduction) {
  constexpr size_t kBatchSizes[] = {1, 4, 7, 11, 19, 37, 99};
  constexpr size_t kNumKVHeads[] = {1, 4};
  constexpr size_t kGroupSizes[] = {1, 8};
  constexpr size_t kPageSizes[] = {1, 16};
  constexpr size_t kHeadDims[] = {64, 128, 256};
  constexpr size_t kKVLenMaxes[] = {0, 3};

  for (const size_t bs : kBatchSizes) {
    for (const size_t kv_heads : kNumKVHeads) {
      for (const size_t gs : kGroupSizes) {
        // The query-head count is the KV-head count times the group size.
        const size_t qo_heads = kv_heads * gs;
        for (const size_t ps : kPageSizes) {
          for (const size_t hd : kHeadDims) {
            for (const size_t kv_upper : kKVLenMaxes) {
              _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness<T>(
                  bs, kv_heads, qo_heads, ps, hd, allow_fp16_qk_reduction,
                  /*q_len_min=*/1, /*q_len_max=*/3, /*kv_len_min=*/0, kv_upper);
            }
          }
        }
      }
    }
  }
}

template <typename T>
void TestBatchRaggedPrefillKernelCorrectness(bool allow_fp16_qk_reduction) {
for (size_t num_kv_heads : {4, 8, 32}) {
Expand Down Expand Up @@ -534,6 +678,10 @@ TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16) {
TestBatchPagedPrefillKernelLongContextCorrectness<half>(false);
}

// Runs the zero-context paged prefill sweep with half-precision data and
// allow_fp16_qk_reduction turned off.
TEST(FlashInferCorrectnessTest, BatchPagedPrefillZeroContextTestFP16) {
  TestBatchPagedPrefillKernelZeroContextCorrectness<half>(/*allow_fp16_qk_reduction=*/false);
}

TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16QKHalfAccum) {
TestBatchPagedPrefillKernelLongContextCorrectness<half>(true);
}
Expand Down

0 comments on commit 7adc8cf

Please sign in to comment.