diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index fc998ded..f9f931ea 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -28,7 +28,7 @@ using namespace flashinfer; constexpr QKVLayout kv_layout = QKVLayout::kNHD; -template +template void bench_flashinfer_batch_decode(nvbench::state& state) { constexpr size_t head_dim = 128; constexpr auto pos_encoding_mode = PosEncodingMode::kNone; @@ -51,12 +51,12 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { kv_indptr_host.push_back(kv_indptr_host.back() + pages_per_seq); kv_last_page_len_host.push_back((seqlen - 1) % page_size + 1); } - thrust::device_vector k_data(num_pages * num_kv_heads * page_size * head_dim); - thrust::device_vector v_data(num_pages * num_kv_heads * page_size * head_dim); + thrust::device_vector k_data(num_pages * num_kv_heads * page_size * head_dim); + thrust::device_vector v_data(num_pages * num_kv_heads * page_size * head_dim); thrust::device_vector kv_indptr(kv_indptr_host); thrust::device_vector kv_indices(kv_indicies_host); thrust::device_vector kv_last_page_len(kv_last_page_len_host); - paged_kv_t paged_kv( + paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()), thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()), @@ -65,7 +65,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { thrust::device_vector q(batch_size * num_qo_heads * head_dim); thrust::device_vector o(batch_size * num_qo_heads * head_dim); state.add_global_memory_reads( - vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(T) + + vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(TKV) + vec_bytes(kv_indptr) + vec_bytes(kv_indices) + vec_bytes(kv_last_page_len), "Read"); state.add_global_memory_writes(vec_bytes(o), "Write"); @@ -76,13 +76,13 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; thrust::device_vector int_buffer(int_workspace_size_in_bytes); // begin forward - BatchDecodeHandlerPlan( + BatchDecodeHandlerPlan( &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); state.exec([&](nvbench::launch&) { - cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( + cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); if (status != cudaSuccess) { @@ -91,7 +91,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { }); } -template +template void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { constexpr size_t head_dim = 128; constexpr auto pos_encoding_mode = PosEncodingMode::kNone; @@ -114,12 +114,12 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { kv_indptr_host.push_back(kv_indptr_host.back() + pages_per_seq); kv_last_page_len_host.push_back((seqlen - 1) % page_size + 1); } - thrust::device_vector k_data(num_pages * num_kv_heads * page_size * head_dim); - thrust::device_vector v_data(num_pages * num_kv_heads * page_size * head_dim); + thrust::device_vector k_data(num_pages * num_kv_heads * page_size * head_dim); + thrust::device_vector v_data(num_pages * num_kv_heads * page_size * head_dim); thrust::device_vector kv_indptr(kv_indptr_host); thrust::device_vector kv_indices(kv_indicies_host); thrust::device_vector kv_last_page_len(kv_last_page_len_host); - paged_kv_t paged_kv( + paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout, thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()), thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()), @@ -134,7 +134,7 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { } thrust::device_vector qo_indptr_d(qo_indptr_h); state.add_global_memory_reads( - vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(T) + + vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(TKV) + vec_bytes(kv_indptr) + vec_bytes(kv_indices) + vec_bytes(kv_last_page_len), "Read"); state.add_global_memory_writes(vec_bytes(o), "Write"); @@ -151,7 +151,7 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { - cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( + cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()), /*lse=*/nullptr, num_qo_heads, @@ -161,25 +161,22 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) -#define BENCH_FLASHINFER_BATCH_DECODE(dtype) \ - auto bench_flashinfer_batch_decode_##dtype##_ = bench_flashinfer_batch_decode; \ - NVBENCH_BENCH(bench_flashinfer_batch_decode_##dtype##_) \ - .set_name("bench_flashinfer_batch_decode_" STR(dtype)) \ - .add_int64_axis("seqlen", \ - {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) \ - .add_int64_axis("batch_size", \ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, \ - 15, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80, 96, 112, 128, \ - 160, 192, 224, 256, 320, 384, 448, 512, 640, 768, 896, 1024}) \ - .add_int64_axis("page_size", {4, 8, 16, 32, 64}) \ - .add_int64_axis("num_qo_heads", {32}) \ +#define BENCH_FLASHINFER_BATCH_DECODE(dtype, dtypeKV) \ + auto bench_flashinfer_batch_decode_##dtype##_ = bench_flashinfer_batch_decode; \ + NVBENCH_BENCH(bench_flashinfer_batch_decode_##dtype##_) \ + .set_name("bench_flashinfer_batch_decode_" STR(dtype) STR(dtypeKV)) \ + .add_int64_axis("seqlen", \ + {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) \ + .add_int64_axis("batch_size", {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024}) \ + .add_int64_axis("page_size", {16}) \ + .add_int64_axis("num_qo_heads", {32}) \ .add_int64_axis("num_kv_heads", {32, 4}) -#define BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(dtype) \ +#define BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(dtype, dtypeKV) \ auto bench_flashinfer_batch_decode_with_prefill_##dtype##_ = \ - bench_flashinfer_batch_decode_with_prefill; \ + bench_flashinfer_batch_decode_with_prefill; \ NVBENCH_BENCH(bench_flashinfer_batch_decode_with_prefill_##dtype##_) \ - .set_name("bench_flashinfer_batch_decode_with_prefill_" STR(dtype)) \ + .set_name("bench_flashinfer_batch_decode_with_prefill_" STR(dtype) STR(dtypeKV)) \ .add_int64_axis("seqlen", \ {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) \ .add_int64_axis("batch_size", {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024}) \ @@ -187,5 +184,5 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { .add_int64_axis("num_qo_heads", {32}) \ .add_int64_axis("num_kv_heads", {32, 4}) -BENCH_FLASHINFER_BATCH_DECODE(half); -BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(half); +kENCH_FLASHINFER_BATCH_DECODE(half, half); +BENCH_FLASHINFER_BATCH_DECODE_WITH_PREFILL(half, half);