Skip to content

Commit

Permalink
test: add DtypeKV template param in bench_batch_decode (#607)
Browse files Browse the repository at this point in the history
Add `typename TKV` so that it's convenient to benchmark FP8 KVCache. An
example output:

![image](https://github.com/user-attachments/assets/ca5f5e75-8b9e-49f8-ae06-31a95b08a9b4)

Co-authored-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
  • Loading branch information
dc3671 and Zhenhuan Chen authored Nov 13, 2024
1 parent be10bbd commit 45e9273
Showing 1 changed file with 27 additions and 30 deletions.
57 changes: 27 additions & 30 deletions src/bench_batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ using namespace flashinfer;

constexpr QKVLayout kv_layout = QKVLayout::kNHD;

template <typename T>
template <typename T, typename TKV>
void bench_flashinfer_batch_decode(nvbench::state& state) {
constexpr size_t head_dim = 128;
constexpr auto pos_encoding_mode = PosEncodingMode::kNone;
Expand All @@ -51,12 +51,12 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
kv_indptr_host.push_back(kv_indptr_host.back() + pages_per_seq);
kv_last_page_len_host.push_back((seqlen - 1) % page_size + 1);
}
thrust::device_vector<T> k_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<T> v_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<TKV> k_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<TKV> v_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<int32_t> kv_indptr(kv_indptr_host);
thrust::device_vector<int32_t> kv_indices(kv_indicies_host);
thrust::device_vector<int32_t> kv_last_page_len(kv_last_page_len_host);
paged_kv_t<T, int32_t> paged_kv(
paged_kv_t<TKV, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout,
thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()),
thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()),
Expand All @@ -65,7 +65,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
thrust::device_vector<T> q(batch_size * num_qo_heads * head_dim);
thrust::device_vector<T> o(batch_size * num_qo_heads * head_dim);
state.add_global_memory_reads<uint8_t>(
vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(T) +
vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(TKV) +
vec_bytes(kv_indptr) + vec_bytes(kv_indices) + vec_bytes(kv_last_page_len),
"Read");
state.add_global_memory_writes<uint8_t>(vec_bytes(o), "Write");
Expand All @@ -76,13 +76,13 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
thrust::device_vector<char> int_buffer(int_workspace_size_in_bytes);
// begin forward
BatchDecodeHandlerPlan<T, T, T, int32_t>(
BatchDecodeHandlerPlan<T, TKV, T, int32_t>(
&handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
(void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads,
head_dim, page_size, pos_encoding_mode);
state.exec([&](nvbench::launch&) {
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<T, T, T, int32_t>(
cudaError_t status = BatchDecodeWithPagedKVCacheWrapper<T, TKV, T, int32_t>(
&handler, thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv,
thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode);
if (status != cudaSuccess) {
Expand All @@ -91,7 +91,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
});
}

template <typename T>
template <typename T, typename TKV>
void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
constexpr size_t head_dim = 128;
constexpr auto pos_encoding_mode = PosEncodingMode::kNone;
Expand All @@ -114,12 +114,12 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
kv_indptr_host.push_back(kv_indptr_host.back() + pages_per_seq);
kv_last_page_len_host.push_back((seqlen - 1) % page_size + 1);
}
thrust::device_vector<T> k_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<T> v_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<TKV> k_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<TKV> v_data(num_pages * num_kv_heads * page_size * head_dim);
thrust::device_vector<int32_t> kv_indptr(kv_indptr_host);
thrust::device_vector<int32_t> kv_indices(kv_indicies_host);
thrust::device_vector<int32_t> kv_last_page_len(kv_last_page_len_host);
paged_kv_t<T, int32_t> paged_kv(
paged_kv_t<TKV, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout,
thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()),
thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()),
Expand All @@ -134,7 +134,7 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
}
thrust::device_vector<int32_t> qo_indptr_d(qo_indptr_h);
state.add_global_memory_reads<uint8_t>(
vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(T) +
vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(TKV) +
vec_bytes(kv_indptr) + vec_bytes(kv_indices) + vec_bytes(kv_last_page_len),
"Read");
state.add_global_memory_writes<uint8_t>(vec_bytes(o), "Write");
Expand All @@ -151,7 +151,7 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, T, T, int32_t>(
cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, TKV, T, int32_t>(
&handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()),
/*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()),
/*lse=*/nullptr, num_qo_heads,
Expand All @@ -161,31 +161,28 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {

#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#define BENCH_FLASHINFER_BATCH_DECODE(dtype) \
auto bench_flashinfer_batch_decode_##dtype##_ = bench_flashinfer_batch_decode<dtype>; \
NVBENCH_BENCH(bench_flashinfer_batch_decode_##dtype##_) \
.set_name("bench_flashinfer_batch_decode_" STR(dtype)) \
.add_int64_axis("seqlen", \
{32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) \
.add_int64_axis("batch_size", \
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, \
15, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80, 96, 112, 128, \
160, 192, 224, 256, 320, 384, 448, 512, 640, 768, 896, 1024}) \
.add_int64_axis("page_size", {4, 8, 16, 32, 64}) \
.add_int64_axis("num_qo_heads", {32}) \
#define BENCH_FLASHINFER_BATCH_DECODE(dtype, dtypeKV) \
auto bench_flashinfer_batch_decode_##dtype##_ = bench_flashinfer_batch_decode<dtype, dtypeKV>; \
NVBENCH_BENCH(bench_flashinfer_batch_decode_##dtype##_) \
.set_name("bench_flashinfer_batch_decode_" STR(dtype) STR(dtypeKV)) \
.add_int64_axis("seqlen", \
{32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) \
.add_int64_axis("batch_size", {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024}) \
.add_int64_axis("page_size", {16}) \
.add_int64_axis("num_qo_heads", {32}) \
.add_int64_axis("num_kv_heads", {32, 4})

#define BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(dtype) \
#define BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(dtype, dtypeKV) \
auto bench_flashinfer_batch_decode_with_prefill_##dtype##_ = \
bench_flashinfer_batch_decode_with_prefill<dtype>; \
bench_flashinfer_batch_decode_with_prefill<dtype, dtypeKV>; \
NVBENCH_BENCH(bench_flashinfer_batch_decode_with_prefill_##dtype##_) \
.set_name("bench_flashinfer_batch_decode_with_prefill_" STR(dtype)) \
.set_name("bench_flashinfer_batch_decode_with_prefill_" STR(dtype) STR(dtypeKV)) \
.add_int64_axis("seqlen", \
{32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) \
.add_int64_axis("batch_size", {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024}) \
.add_int64_axis("page_size", {16}) \
.add_int64_axis("num_qo_heads", {32}) \
.add_int64_axis("num_kv_heads", {32, 4})

BENCH_FLASHINFER_BATCH_DECODE(half);
BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(half);
BENCH_FLASHINFER_BATCH_DECODE(half, half);
BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(half, half);

0 comments on commit 45e9273

Please # to comment.