From a7ee5662bf967ab1ee16910c73761d326fbeb9a0 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 13 Aug 2024 03:02:38 -0700 Subject: [PATCH] feat: decouple float and int workspace buffer (#442) Before this PR, flashinfer coupled float and int buffers in a single workspace buffer, and different wrappers cannot share the same buffers. This PR decouples float and int workspace buffer. The float workspace buffer (large) can be shared in multiple wrappers, and the int buffer (small) is unique for each wrapper. This PR can save GPU memory when multiple wrappers are created (decode, prefill paged, prefill ragged) or cascade inference. --- include/flashinfer/attention/handler.cuh | 149 ++++++++++++----------- python/csrc/activation.cu | 3 +- python/csrc/batch_decode.cu | 38 +++--- python/csrc/batch_prefill.cu | 54 +++++--- python/csrc/flashinfer_ops_decode.h | 10 +- python/csrc/flashinfer_ops_prefill.h | 15 +-- python/flashinfer/cascade.py | 52 +++++--- python/flashinfer/decode.py | 39 ++++-- python/flashinfer/prefill.py | 72 +++++++---- python/flashinfer/sparse.py | 39 ++++-- python/setup.py | 4 +- src/bench_batch_decode.cu | 26 ++-- src/bench_batch_prefill.cu | 15 ++- src/bench_cascade.cu | 52 +++++--- src/flashinfer_ops.cuh | 9 +- src/test_batch_decode.cu | 9 +- src/test_batch_prefill.cu | 59 ++++++--- src/test_cascade.cu | 44 ++++--- src/tvm_wrapper.cu | 53 +++++--- 19 files changed, 467 insertions(+), 275 deletions(-) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index e89c762e..0d721314 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -309,8 +309,9 @@ class BatchDecodeHandler { template - cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr_h, - IdType* last_page_len_h, uint32_t batch_size, + cudaError_t BeginForwardDispatched(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, size_t int_workspace_size_in_bytes, + IdType* indptr_h, IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t page_size) { batch_size_before_partition_ = batch_size; @@ -337,44 +338,45 @@ class BatchDecodeHandler { size_t padded_batch_size = max_grid_size / num_kv_heads; if (split_kv) { padded_batch_size_ = padded_batch_size; - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - tmp_v_ = allocator.aligned_alloc( + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + tmp_v_ = float_allocator.aligned_alloc( num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16, "batch_decode_tmp_v"); - tmp_s_ = allocator.aligned_alloc(num_qo_heads * padded_batch_size * sizeof(float), - 16, "batch_decode_tmp_s"); - new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16, - "batch_decode_new_indptr"); + tmp_s_ = float_allocator.aligned_alloc( + num_qo_heads * padded_batch_size * sizeof(float), 16, "batch_decode_tmp_s"); + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + new_indptr_ = int_allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), + 16, "batch_decode_new_indptr"); void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16, - "batch_decode_new_last_page_len"); + new_last_page_len_ = int_allocator.aligned_alloc( + padded_batch_size * sizeof(IdType), 16, "batch_decode_new_last_page_len"); void* new_last_page_len_h_ = 
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), - 16, "batch_decode_chunk_indptr"); + chunk_indptr_ = int_allocator.aligned_alloc( + (padded_batch_size + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr"); void* chunk_indptr_h_ = (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16, - "batch_decode_batch_idx_map"); + batch_idx_map_ = int_allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16, + "batch_decode_batch_idx_map"); void* batch_idx_map_h_ = (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16, - "batch_decode_chunk_start_pos"); + chunk_start_pos_ = int_allocator.aligned_alloc(padded_batch_size * sizeof(IdType), + 16, "batch_decode_chunk_start_pos"); void* chunk_start_pos_h_ = (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = allocator.aligned_alloc( + seq_lengths_before_partition_ = int_allocator.aligned_alloc( padded_batch_size * sizeof(IdType), 16, "batch_decode_seq_lengths_before_partition"); void* seq_lengths_before_partition_h_ = (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - block_valid_mask_ = allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16, - "batch_decode_block_valid_mask"); + block_valid_mask_ = int_allocator.aligned_alloc( + padded_batch_size * sizeof(bool), 16, "batch_decode_block_valid_mask"); bool* block_valid_mask_h_ = (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_); std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)new_indptr_; FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr_h, last_page_len_h, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, @@ -392,38 +394,39 @@ class BatchDecodeHandler { // do not pad the batch size when not using CUDAGraph padded_batch_size_ = batch_size_after_partition_; if (split_kv) { - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - tmp_v_ = allocator.aligned_alloc( + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + tmp_v_ = float_allocator.aligned_alloc( num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16, "batch_decode_tmp_v"); - tmp_s_ = allocator.aligned_alloc(num_qo_heads * new_batch_size * sizeof(float), 16, - "batch_decode_tmp_s"); - new_indptr_ = allocator.aligned_alloc( + tmp_s_ = float_allocator.aligned_alloc( + num_qo_heads * new_batch_size * sizeof(float), 16, "batch_decode_tmp_s"); + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + new_indptr_ = int_allocator.aligned_alloc( (batch_size_after_partition_ + 1) * sizeof(IdType), 16, "batch_decode_new_indptr"); void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = allocator.aligned_alloc( + new_last_page_len_ = int_allocator.aligned_alloc( batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_new_last_page_len"); void* new_last_page_len_h_ = (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = 
allocator.aligned_alloc( + chunk_indptr_ = int_allocator.aligned_alloc( (batch_size_before_partition_ + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr"); void* chunk_indptr_h_ = (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = allocator.aligned_alloc( + batch_idx_map_ = int_allocator.aligned_alloc( batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_batch_idx_map"); void* batch_idx_map_h_ = (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = allocator.aligned_alloc( + chunk_start_pos_ = int_allocator.aligned_alloc( batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_chunk_start_pos"); void* chunk_start_pos_h_ = (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); seq_lengths_before_partition_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16, - "batch_decode_seq_lengths_before_partition"); + int_allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16, + "batch_decode_seq_lengths_before_partition"); void* seq_lengths_before_partition_h_ = (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)new_indptr_; FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr_h, last_page_len_h, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, @@ -458,9 +461,9 @@ class BatchDecodeHandler { bool IsForwardStarted() const { return forward_started_; } - void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) { + void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { cudaFreeHost(page_locked_buffer_); - cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); + cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); } uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; } @@ -655,15 +658,17 @@ class BatchPrefillHandler { bool IsForwardStarted() const { return request_indices_ != nullptr; } - void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) { + void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { cudaFreeHost(page_locked_buffer_); - cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); + cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); } template - cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr_h, - IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size) { + cudaError_t BeginForward(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, size_t int_workspace_size_in_bytes, + IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " @@ -683,35 +688,35 @@ class BatchPrefillHandler { if (IsCUDAGraphEnabled()) { padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q); - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - request_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - 
"batch_prefill_request_indices"); + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + request_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, + "batch_prefill_request_indices"); void* request_indices_h_ = page_locked_buffer_; - qo_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_qo_tile_indices"); + qo_tile_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, + "batch_prefill_qo_tile_indices"); void* qo_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); - kv_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_kv_tile_indices"); + kv_tile_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, + "batch_prefill_kv_tile_indices"); void* kv_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); - o_indptr_ = allocator.aligned_alloc(sizeof(IdType) * (batch_size + 1), 16, - "batch_prefill_o_indptr"); + o_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * (batch_size + 1), 16, + "batch_prefill_o_indptr"); void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); kv_chunk_size_ptr_ = - allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + int_allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); void* kv_chunk_size_ptr_h_ = (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; if (total_num_tiles_q < split_max_batch_size) { // need merge_indptr - merge_indptr_ = allocator.aligned_alloc(sizeof(IdType) * (total_num_rows_ + 1), 16, - "batch_prefill_merge_indptr"); + merge_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * (total_num_rows_ + 1), + 16, "batch_prefill_merge_indptr"); void* merge_indptr_h_ = (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); - block_valid_mask_ = allocator.aligned_alloc(sizeof(bool) * padded_batch_size_, 16, - "batch_prefill_block_valid_mask"); + block_valid_mask_ = int_allocator.aligned_alloc(sizeof(bool) * padded_batch_size_, 16, + "batch_prefill_block_valid_mask"); bool* block_valid_mask_h_ = (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)request_indices_); for (uint32_t i = 0; i < padded_batch_size_; ++i) { @@ -731,15 +736,16 @@ class BatchPrefillHandler { (IdType*)kv_tile_indices_h_); std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; + size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)request_indices_; FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, cudaMemcpyHostToDevice, stream_)) if (total_num_tiles_q < split_max_batch_size) { - tmp_v_ = allocator.aligned_alloc( + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + tmp_v_ = float_allocator.aligned_alloc( num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16, "batch_prefill_tmp_v"); - tmp_s_ = allocator.aligned_alloc( + tmp_s_ = float_allocator.aligned_alloc( num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16, "batch_prefill_tmp_s"); } else { @@ -748,31 +754,31 @@ class 
BatchPrefillHandler { } } else { padded_batch_size_ = new_batch_size; - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - request_indices_ = allocator.aligned_alloc(sizeof(IdType) * request_indices_vec.size(), - 16, "batch_prefill_request_indices"); + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + request_indices_ = int_allocator.aligned_alloc( + sizeof(IdType) * request_indices_vec.size(), 16, "batch_prefill_request_indices"); void* request_indices_h_ = page_locked_buffer_; - qo_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * qo_tile_indices_vec.size(), - 16, "batch_prefill_qo_tile_indices"); + qo_tile_indices_ = int_allocator.aligned_alloc( + sizeof(IdType) * qo_tile_indices_vec.size(), 16, "batch_prefill_qo_tile_indices"); void* qo_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); - kv_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * kv_tile_indices_vec.size(), - 16, "batch_prefill_kv_tile_indices"); + kv_tile_indices_ = int_allocator.aligned_alloc( + sizeof(IdType) * kv_tile_indices_vec.size(), 16, "batch_prefill_kv_tile_indices"); void* kv_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); if (split_kv) { // need merge_indptr when split_kv is true - merge_indptr_ = allocator.aligned_alloc(sizeof(IdType) * merge_indptr_vec.size(), 16, - "batch_prefill_merge_indptr"); + merge_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * merge_indptr_vec.size(), + 16, "batch_prefill_merge_indptr"); void* merge_indptr_h_ = (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); } - o_indptr_ = allocator.aligned_alloc(sizeof(IdType) * o_indptr_vec.size(), 16, - "batch_prefill_o_indptr"); + o_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * o_indptr_vec.size(), 16, + "batch_prefill_o_indptr"); void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); kv_chunk_size_ptr_ = - allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + int_allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); void* kv_chunk_size_ptr_h_ = (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; @@ -783,16 +789,17 @@ class BatchPrefillHandler { std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), (IdType*)kv_tile_indices_h_); std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; + size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)request_indices_; FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, cudaMemcpyHostToDevice, stream_)) if (split_kv) { - tmp_v_ = allocator.aligned_alloc( + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + tmp_v_ = float_allocator.aligned_alloc( num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16, "batch_prefill_tmp_v"); - tmp_s_ = allocator.aligned_alloc( + tmp_s_ = float_allocator.aligned_alloc( num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16, "batch_prefill_tmp_s"); } else { diff --git a/python/csrc/activation.cu b/python/csrc/activation.cu index e4dcf7c0..4830334b 100644 --- a/python/csrc/activation.cu +++ 
b/python/csrc/activation.cu @@ -51,8 +51,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); - flashinfer::activation::act_and_mul_kernel + flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 94365376..92ce222f 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -21,22 +21,28 @@ using namespace flashinfer; void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, - float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data) { - CHECK_INPUT(workspace_buffer); + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor indptr, + torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, + unsigned int pos_encoding_mode, float logits_soft_cap, torch::Tensor empty_q_data, + torch::Tensor empty_kv_data) { + CHECK_INPUT(float_workspace_buffer); + CHECK_INPUT(int_workspace_buffer); // NOTE(zihao): not necessary to be CUDA tensor CHECK_CONTIGUOUS(indptr); CHECK_CONTIGUOUS(last_page_len); CHECK_DIM(1, indptr); CHECK_DIM(1, last_page_len); - CHECK_DIM(1, workspace_buffer); + CHECK_DIM(1, float_workspace_buffer); + CHECK_DIM(1, int_workspace_buffer); CHECK_EQ(indptr.scalar_type(), torch::kInt32); CHECK_EQ(indptr.scalar_type(), torch::kInt32); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); - auto device = workspace_buffer.device(); + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); handler_->SetCUDAStream(torch_current_stream); indptr = indptr.to(torch::kCPU); @@ -59,8 +65,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), + static_cast(float_workspace_buffer.data_ptr()), + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, static_cast(indptr.data_ptr()), static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, num_kv_heads, page_size); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", @@ -81,8 +89,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), + static_cast(float_workspace_buffer.data_ptr()), + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, static_cast(indptr.data_ptr()), 
static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, num_kv_heads, page_size); TORCH_CHECK(status == cudaSuccess, @@ -100,8 +110,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int max_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); + unsigned int int_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); } std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 85c4b3ca..86bea032 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -21,29 +21,36 @@ using namespace flashinfer; void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) { - CHECK_INPUT(workspace_buffer); + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + unsigned int page_size, torch::Tensor empty_q_data) { + CHECK_INPUT(float_workspace_buffer); + CHECK_INPUT(int_workspace_buffer); // NOTE(Zihao): not necessary to be a CUDA tensor CHECK_CONTIGUOUS(qo_indptr); CHECK_CONTIGUOUS(paged_kv_indptr); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_DIM(1, qo_indptr); CHECK_DIM(1, paged_kv_indptr); - CHECK_DIM(1, workspace_buffer); + CHECK_DIM(1, float_workspace_buffer); + CHECK_DIM(1, int_workspace_buffer); CHECK_EQ(qo_indptr.size(0), batch_size + 1); CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1); qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - auto device = workspace_buffer.device(); - size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); handler_->SetCUDAStream(torch_current_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { cudaError_t status = handler_->BeginForward( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); @@ -56,8 +63,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int max_workspace_size_in_bytes) { - 
handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); + unsigned int int_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); } std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( @@ -446,28 +453,35 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu } void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, torch::Tensor empty_q_data) { - CHECK_INPUT(workspace_buffer); + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + torch::Tensor empty_q_data) { + CHECK_INPUT(float_workspace_buffer); + CHECK_INPUT(int_workspace_buffer); // NOTE(Zihao): not necessary to be a CUDA tensor CHECK_CONTIGUOUS(qo_indptr); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_DIM(1, qo_indptr); CHECK_DIM(1, kv_indptr); - CHECK_DIM(1, workspace_buffer); + CHECK_DIM(1, float_workspace_buffer); + CHECK_DIM(1, int_workspace_buffer); CHECK_EQ(qo_indptr.size(0), batch_size + 1); CHECK_EQ(kv_indptr.size(0), batch_size + 1); qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); - auto device = workspace_buffer.device(); + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); handler_->SetCUDAStream(torch_current_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { cudaError_t status = handler_->BeginForward( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1); @@ -480,8 +494,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int max_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); + unsigned int int_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); } std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( diff --git a/python/csrc/flashinfer_ops_decode.h b/python/csrc/flashinfer_ops_decode.h index 1f955a7f..0c383c04 100644 --- a/python/csrc/flashinfer_ops_decode.h +++ b/python/csrc/flashinfer_ops_decode.h @@ -28,13 +28,13 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc class BatchDecodeWithPagedKVCachePyTorchWrapper { public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, - 
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int pos_encoding_mode, float logits_soft_cap, + void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + unsigned int page_size, unsigned int pos_encoding_mode, float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data); void EndForward(); - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } std::vector Forward(torch::Tensor q, std::optional paged_kv_cache, std::optional paged_k_cache, diff --git a/python/csrc/flashinfer_ops_prefill.h b/python/csrc/flashinfer_ops_prefill.h index 949da9ae..b895f8b1 100644 --- a/python/csrc/flashinfer_ops_prefill.h +++ b/python/csrc/flashinfer_ops_prefill.h @@ -34,13 +34,13 @@ std::vector single_prefill_with_kv_cache_custom_mask( class BatchPrefillWithPagedKVCachePyTorchWrapper { public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor page_kv_indptr, unsigned int batch_size, + void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, torch::Tensor page_kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned page_size, torch::Tensor empty_q_data); void EndForward(); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, std::optional paged_k_cache, @@ -69,12 +69,13 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { class BatchPrefillWithRaggedKVCachePyTorchWrapper { public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data); + void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + torch::Tensor empty_q_data); void EndForward(); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 1a2ad7b4..148b54f6 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -257,22 +257,32 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: manages the lifecycle of these data structures. 
""" - def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None: + def __init__( + self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" + ) -> None: self._batch_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + float_workspace_buffer, kv_layout ) self._kv_layout = kv_layout - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ - self._batch_decode_wrapper.reset_workspace_buffer(new_workspace_buffer) + self._batch_decode_wrapper.reset_workspace_buffer( + float_workspace_buffer, int_workspace_buffer + ) def begin_forward( self, @@ -503,33 +513,43 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: layers). This wrapper class manages the lifecycle of these data structures. """ - def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None: + def __init__( + self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" + ) -> None: r"""Constructor of :class:`BatchDecodeWithSharedPrefixPagedKVCacheWrapper`. Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace + buffer should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. """ self._batch_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + float_workspace_buffer, kv_layout ) self._kv_layout = kv_layout - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. 
""" - self._batch_prefill_wrapper.reset_workspace_buffer(new_workspace_buffer) + self._batch_prefill_wrapper.reset_workspace_buffer( + float_workspace_buffer, int_workspace_buffer + ) def begin_forward( self, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 123ca8b6..ebb652f8 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -279,7 +279,7 @@ class BatchDecodeWithPagedKVCacheWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, use_tensor_cores: bool = False, @@ -291,9 +291,9 @@ def __init__( Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. kv_layout : str @@ -326,7 +326,10 @@ def __init__( """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) if use_cuda_graph: if not torch.is_tensor(paged_kv_indptr_buffer): @@ -363,7 +366,7 @@ def __init__( self._qo_indptr_buf = torch.arange( self._fixed_batch_size + 1, dtype=torch.int32, - device=workspace_buffer.device, + device=float_workspace_buffer.device, ) else: self._use_tensor_cores = False @@ -381,16 +384,26 @@ def use_tensor_cores(self) -> bool: def is_cuda_graph_enabled(self) -> bool: return self._wrapper.is_cuda_graph_enabled() - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. 
""" - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) def begin_forward( self, @@ -511,7 +524,8 @@ def begin_forward( batch_size + 1, dtype=torch.int32, device=indptr.device ) self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, self._qo_indptr_buf, indptr, batch_size, @@ -523,7 +537,8 @@ def begin_forward( ) else: self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, indptr, last_page_len, batch_size, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 8c626d2c..8e01ee48 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -541,7 +541,7 @@ class BatchPrefillWithPagedKVCacheWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, qo_indptr_buf: Optional[torch.Tensor] = None, @@ -555,10 +555,10 @@ def __init__( Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved workspace buffer used to store intermediate attention results in + split-k algorithm. The recommended size is 128MB, the device of the workspace buffer + should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. @@ -603,7 +603,10 @@ def __init__( """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, @@ -649,16 +652,26 @@ def __init__( def is_cuda_graph_enabled(self) -> bool: return self._wrapper.is_cuda_graph_enabled() - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. 
""" - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) def begin_forward( self, @@ -789,7 +802,8 @@ def begin_forward( ), ) self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, qo_indptr, paged_kv_indptr, batch_size, @@ -1176,7 +1190,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, qo_indptr_buf: Optional[torch.Tensor] = None, @@ -1188,10 +1202,10 @@ def __init__( Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace + buffer should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. @@ -1224,7 +1238,10 @@ def __init__( """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) self._wrapper = _prefill.BatchPrefillWithRaggedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, @@ -1257,16 +1274,26 @@ def __init__( def is_cuda_graph_enabled(self) -> bool: return self._wrapper.is_cuda_graph_enabled() - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. 
""" - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) def begin_forward( self, @@ -1376,7 +1403,8 @@ def begin_forward( ), ) self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, qo_indptr, kv_indptr, batch_size, diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py index f80f1231..abfb5bd1 100644 --- a/python/flashinfer/sparse.py +++ b/python/flashinfer/sparse.py @@ -111,18 +111,21 @@ class BlockSparseAttentionWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, ) -> None: r"""Constructs of :class:`BlockSparseAttentionWrapper`. Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace + buffer should be the same as the device of the input tensors. """ - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( TensorLayout["NHD"].value, False, # use_cuda_graph @@ -138,6 +141,27 @@ def __init__( self.M: Optional[int] = None self.N: Optional[int] = None + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: + r"""Reset the workspace buffer. + + Parameters + ---------- + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should + be the same as the device of the input tensors. 
+ """ + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) + def begin_forward( self, indptr: torch.Tensor, @@ -244,7 +268,8 @@ def begin_forward( self.C = C self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, self._qo_indptr, self._paged_kv_indptr_buf, num_blocks_row, diff --git a/python/setup.py b/python/setup.py index 9566da6c..a23c6ac3 100644 --- a/python/setup.py +++ b/python/setup.py @@ -313,9 +313,7 @@ def __init__(self, *args, **kwargs) -> None: files_prefill, files_decode = get_instantiation_cu() include_dirs = [ str(root.resolve() / "include"), - str( - root.resolve() / "3rdparty" / "cutlass" / "include" - ), # for group gemm + str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm ] extra_compile_args = { "cxx": [ diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 024a911e..d81f22a9 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -71,13 +72,16 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { BatchDecodeHandler handler; if (cooperative) { - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); // begin forward BatchDecodeHandlerBeginForward( - &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, pos_encoding_mode); + &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), + float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), + int_workspace_size_in_bytes, kv_indptr_host.data(), kv_last_page_len_host.data(), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); state.exec([&](nvbench::launch&) { cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( @@ -148,12 +152,16 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { "Read"); state.add_global_memory_writes(vec_bytes(o), "Write"); BatchPrefillHandler handler; - size_t workspace_size_in_bytes = 128 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 128 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), - kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { cudaError_t status = diff --git a/src/bench_batch_prefill.cu 
b/src/bench_batch_prefill.cu index 81875436..d04e838a 100644 --- a/src/bench_batch_prefill.cu +++ b/src/bench_batch_prefill.cu @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -48,7 +49,10 @@ void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) { thrust::device_vector K(batch_size * kv_len * num_kv_heads * head_dim); thrust::device_vector V(batch_size * kv_len * num_kv_heads * head_dim); thrust::device_vector O(batch_size * qo_len * num_qo_heads * head_dim); - thrust::device_vector workspace(128 * 1024 * 1024); + size_t float_workspace_size_in_bytes = 128 * 1024 * 1024; + thrust::device_vector float_workspace(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_workspace(int_workspace_size_in_bytes); // Provide throughput information: state.add_global_memory_reads( @@ -69,10 +73,11 @@ void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) { BatchPrefillHandler handler; - handler.BeginForward(thrust::raw_pointer_cast(workspace.data()), - workspace.size() * sizeof(uint8_t), qo_indptr_h.data(), - kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, /*page_size=*/1); + handler.BeginForward( + thrust::raw_pointer_cast(float_workspace.data()), float_workspace_size_in_bytes, + thrust::raw_pointer_cast(int_workspace.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, + /*page_size=*/1); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index 0312e06c..56cc3d09 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -107,12 +108,15 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); BatchDecodeHandler cascade_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); BatchDecodeHandlerBeginForward( - &cascade_handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); + &cascade_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), + float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), + int_workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); @@ -165,12 +169,16 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); BatchDecodeHandler baseline_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 
* 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); BatchDecodeHandlerBeginForward( - &baseline_handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); + &baseline_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), + float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), + int_workspace_size_in_bytes, kv_indptr_combined_h.data(), + kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size, PosEncodingMode::kNone); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); @@ -241,11 +249,15 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); BatchPrefillHandler cascade_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); cascade_handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), - kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, + head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = SinglePrefillWithKVCache( @@ -298,11 +310,15 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); BatchPrefillHandler baseline_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); baseline_handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), - kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, + head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 9fce7ed9..753cb244 100644 --- 
a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -279,8 +279,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( template -cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* buffer, - size_t workspace_size_in_bytes, IdType* indptr_h, +cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* float_buffer, + size_t float_workspace_size_in_bytes, void* int_buffer, + size_t int_workspace_size_in_bytes, IdType* indptr_h, IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, @@ -295,8 +296,8 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { return handler->BeginForwardDispatched( - buffer, workspace_size_in_bytes, indptr_h, last_page_len_h, batch_size, num_qo_heads, - num_kv_heads, page_size); + float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes, + indptr_h, last_page_len_h, batch_size, num_qo_heads, num_kv_heads, page_size); }); }); } diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 64b9b290..1debe0b0 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -98,10 +98,13 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si thrust::raw_pointer_cast(kv_indptr_device.data()), thrust::raw_pointer_cast(kv_last_page_len_device.data())); flashinfer::BatchDecodeHandler handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); BatchDecodeHandlerBeginForward( - &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, + &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr.data(), kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index 39602710..26b00607 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -82,8 +82,10 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); BatchPrefillHandler handler; - size_t workspace_size_in_bytes = 128 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 128 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { // create one-hot queries @@ -104,8 +106,10 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n thrust::device_vector o_device(q_len * num_qo_heads * head_dim); handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(), - kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + 
diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu
index 39602710..26b00607 100644
--- a/src/test_batch_prefill.cu
+++ b/src/test_batch_prefill.cu
@@ -82,8 +82,10 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
   paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data());
 
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 128 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 128 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
 
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
     // create one-hot queries
@@ -104,8 +106,10 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
     thrust::device_vector o_device(q_len * num_qo_heads * head_dim);
 
     handler.BeginForward(
-        (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(),
-        kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+        (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+        (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+        q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+        page_size);
 
     for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) {
       auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper
 output_refs;
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 128 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 128 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
 
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
     std::vector q(q_lens[request_idx] * num_qo_heads * head_dim);
@@ -191,8 +197,10 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo
   thrust::device_vector kv_indptr_device(kv_indptr);
 
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
-      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1);
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      append_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+      /*page_size=*/1);
 
   auto status = BatchPrefillWithRaggedKVCacheWrapper(
       &handler, thrust::raw_pointer_cast(queries_device.data()),
@@ -315,12 +323,16 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
   thrust::device_vector o_device(o_concat_ref.size());
 
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
 
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(),
-      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+      page_size);
 
   auto status = BatchPrefillWithPagedKVCacheWrapper(
@@ -438,12 +450,16 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
   thrust::device_vector o_device(o_concat_ref.size());
 
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
 
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(),
-      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+      page_size);
 
   auto status = BatchPrefillWithPagedKVCacheWrapper(
@@ -534,12 +550,15 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
   thrust::device_vector o_device(q_lens[0] * num_qo_heads * head_dim);
 
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
 
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
-      kv_indptr.data(),
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      append_indptr.data(), kv_indptr.data(),
       /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size);
 
   auto status = BatchPrefillWithPagedKVCacheWrapper
diff --git a/src/test_cascade.cu b/src/test_cascade.cu
 buffer_baseline(workspace_size_in_bytes),
-      buffer_cascade(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
   BatchDecodeHandlerBeginForward(
-      &baseline_handler, (void*)thrust::raw_pointer_cast(buffer_baseline.data()),
-      workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
+      &baseline_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()),
+      float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()),
+      int_workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone);
   BatchDecodeHandlerBeginForward(
-      &cascade_handler, (void*)thrust::raw_pointer_cast(buffer_cascade.data()),
-      workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(),
+      &cascade_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()),
+      float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()),
+      int_workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(),
       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone);
 
   // Compute result using baseline implementation
@@ -408,18 +411,21 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
       thrust::raw_pointer_cast(kv_last_page_len_unique_d.data()));
 
   BatchPrefillHandler baseline_handler, cascade_handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer_baseline(workspace_size_in_bytes),
-      buffer_cascade(workspace_size_in_bytes);
-
-  baseline_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_baseline.data()),
-                                workspace_size_in_bytes, qo_indptr_h.data(),
-                                kv_indptr_combined_h.data(), batch_size, num_qo_heads,
-                                num_kv_heads, head_dim, page_size);
-  cascade_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_cascade.data()),
-                               workspace_size_in_bytes, qo_indptr_h.data(),
-                               kv_indptr_unique_h.data(), batch_size, num_qo_heads,
-                               num_kv_heads, head_dim, page_size);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
+
+  baseline_handler.BeginForward(
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      qo_indptr_h.data(), kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads,
+      head_dim, page_size);
+  cascade_handler.BeginForward(
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads,
+      head_dim, page_size);
 
   cudaError_t status = BatchPrefillWithPagedKVCacheWrapper(
       &baseline_handler, thrust::raw_pointer_cast(q_d.data()),
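In the cascade test above, both handlers now receive the same float workspace, which is the sharing the decoupling makes possible. A sketch of that pattern in isolation follows, with the large float scratch shared and each handler given its own small int scratch; the <char> element type, the sizes, and the per-handler int buffers are illustrative assumptions rather than part of the patch.

    #include <thrust/device_vector.h>
    #include <cstddef>
    #include <vector>

    // Illustrative only: one large float workspace shared by several handlers, plus a
    // small int workspace per handler. Element type and sizes are assumptions.
    struct SharedWorkspaces {
      thrust::device_vector<char> float_buffer;              // shared scratch (large)
      std::vector<thrust::device_vector<char>> int_buffers;  // one per handler (small)
    };

    SharedWorkspaces MakeSharedWorkspaces(size_t num_handlers) {
      SharedWorkspaces ws{thrust::device_vector<char>(32 * 1024 * 1024), {}};
      for (size_t i = 0; i < num_handlers; ++i) {
        ws.int_buffers.emplace_back(8 * 1024 * 1024);
      }
      // Handler i then receives, as its first four BeginForward arguments:
      //   raw_pointer_cast(ws.float_buffer.data()), ws.float_buffer.size(),
      //   raw_pointer_cast(ws.int_buffers[i].data()), ws.int_buffers[i].size()
      return ws;
    }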
diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu
index c8da983b..3ccb4fa5 100644
--- a/src/tvm_wrapper.cu
+++ b/src/tvm_wrapper.cu
@@ -272,11 +272,15 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
 }
 
 void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
-    int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr,
-    int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
-    int64_t page_size, TVMStreamHandle copy_stream) {
-  CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
-  size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+    int64_t handler_idx, DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer,
+    DLTensor* qo_indptr, DLTensor* kv_indptr, int64_t batch_size, int64_t num_qo_heads,
+    int64_t num_kv_heads, int64_t head_dim, int64_t page_size, TVMStreamHandle copy_stream) {
+  CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor";
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8;
+  CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor";
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8;
   CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
   // NOTE(Zihao): here we presume the input data type is half, in the future we should
@@ -288,7 +292,8 @@ void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_paged_kv_handlers[handler_idx].BeginForward(
-        static_cast(workspace_buffer->data), workspace_size_in_bytes,
+        static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes,
+        static_cast(int_workspace_buffer->data), int_workspace_size_in_bytes,
         static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
         static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
         batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
@@ -416,11 +421,16 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_
 }
 
 void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
-    int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* page_table_indptr,
-    DLTensor* last_page_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
-    int64_t page_size, int64_t pos_encoding_mode, TVMStreamHandle copy_stream) {
-  CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
-  size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+    int64_t handler_idx, DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer,
+    DLTensor* page_table_indptr, DLTensor* last_page_len, int64_t num_qo_heads,
+    int64_t num_kv_heads, int64_t head_dim, int64_t page_size, int64_t pos_encoding_mode,
+    TVMStreamHandle copy_stream) {
+  CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor";
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8;
+  CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor";
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8;
   CHECK_LT(handler_idx, max_num_handlers)
       << "The handler id must be less than " << max_num_handlers;
   constexpr PageStorage page_storage = PageStorage::kIndices;
@@ -433,8 +443,9 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
   DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, {
     cudaError_t status = BatchDecodeHandlerBeginForward(
-        batch_decode_handlers + handler_idx, static_cast(workspace_buffer->data),
-        workspace_size_in_bytes,
+        batch_decode_handlers + handler_idx, static_cast(float_workspace_buffer->data),
+        float_workspace_size_in_bytes, static_cast(int_workspace_buffer->data),
+        int_workspace_size_in_bytes,
         static_cast(page_table_indptr->data) +
             page_table_indptr->byte_offset / sizeof(dtype_idx),
         static_cast(last_page_len->data) +
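Both TVM entry points above derive each workspace size with the same shape[0] * dtype.bits / 8 computation on a 1-D DLTensor. A standalone sketch of that computation follows; the helper name is ours, the wrappers simply inline the expression.

    #include <dlpack/dlpack.h>
    #include <cstddef>

    // Byte size of a 1-D workspace tensor, as computed by the wrapper functions above.
    inline size_t WorkspaceSizeInBytes(const DLTensor* workspace_buffer) {
      // Callers are expected to have verified workspace_buffer->ndim == 1 already,
      // mirroring the CHECK_EQ(...) guards in the diff.
      return static_cast<size_t>(workspace_buffer->shape[0]) * workspace_buffer->dtype.bits / 8;
    }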
@@ -551,10 +562,15 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(
 }
 
 void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
-    DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, int64_t batch_size,
-    int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) {
-  CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
-  size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+    DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, DLTensor* qo_indptr,
+    DLTensor* kv_indptr, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
+    int64_t head_dim, TVMStreamHandle copy_stream) {
+  CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor";
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8;
+  CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor";
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8;
   cudaStream_t original_stream = batch_prefill_ragged_kv_handler.GetCUDAStream();
   batch_prefill_ragged_kv_handler.SetCUDAStream(static_cast(copy_stream));
@@ -564,7 +580,8 @@ void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_ragged_kv_handler.BeginForward(
-        static_cast(workspace_buffer->data), workspace_size_in_bytes,
+        static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes,
+        static_cast(int_workspace_buffer->data), int_workspace_size_in_bytes,
         static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
         static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
         batch_size, num_qo_heads, num_kv_heads, head_dim,