diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index e89c762e..0d721314 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -309,8 +309,9 @@ class BatchDecodeHandler {
   template
-  cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr_h,
-                                     IdType* last_page_len_h, uint32_t batch_size,
+  cudaError_t BeginForwardDispatched(void* float_buffer, size_t float_workspace_size_in_bytes,
+                                     void* int_buffer, size_t int_workspace_size_in_bytes,
+                                     IdType* indptr_h, IdType* last_page_len_h, uint32_t batch_size,
                                      uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t page_size) {
     batch_size_before_partition_ = batch_size;
@@ -337,44 +338,45 @@ class BatchDecodeHandler {
     size_t padded_batch_size = max_grid_size / num_kv_heads;
     if (split_kv) {
       padded_batch_size_ = padded_batch_size;
-      AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-      tmp_v_ = allocator.aligned_alloc(
+      AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
+      tmp_v_ = float_allocator.aligned_alloc(
          num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16, "batch_decode_tmp_v");
-      tmp_s_ = allocator.aligned_alloc(num_qo_heads * padded_batch_size * sizeof(float),
-                                       16, "batch_decode_tmp_s");
-      new_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType), 16,
-                                            "batch_decode_new_indptr");
+      tmp_s_ = float_allocator.aligned_alloc(
+          num_qo_heads * padded_batch_size * sizeof(float), 16, "batch_decode_tmp_s");
+      AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
+      new_indptr_ = int_allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType),
+                                                16, "batch_decode_new_indptr");
       void* new_indptr_h_ = page_locked_buffer_;
-      new_last_page_len_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16,
-                                                   "batch_decode_new_last_page_len");
+      new_last_page_len_ = int_allocator.aligned_alloc(
+          padded_batch_size * sizeof(IdType), 16, "batch_decode_new_last_page_len");
       void* new_last_page_len_h_ = (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
-      chunk_indptr_ = allocator.aligned_alloc((padded_batch_size + 1) * sizeof(IdType),
-                                              16, "batch_decode_chunk_indptr");
+      chunk_indptr_ = int_allocator.aligned_alloc(
+          (padded_batch_size + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr");
       void* chunk_indptr_h_ = (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-      batch_idx_map_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16,
-                                               "batch_decode_batch_idx_map");
+      batch_idx_map_ = int_allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16,
+                                                   "batch_decode_batch_idx_map");
       void* batch_idx_map_h_ = (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-      chunk_start_pos_ = allocator.aligned_alloc(padded_batch_size * sizeof(IdType), 16,
-                                                 "batch_decode_chunk_start_pos");
+      chunk_start_pos_ = int_allocator.aligned_alloc(padded_batch_size * sizeof(IdType),
+                                                     16, "batch_decode_chunk_start_pos");
       void* chunk_start_pos_h_ = (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
-      seq_lengths_before_partition_ = allocator.aligned_alloc(
+      seq_lengths_before_partition_ = int_allocator.aligned_alloc(
          padded_batch_size * sizeof(IdType), 16, "batch_decode_seq_lengths_before_partition");
       void* seq_lengths_before_partition_h_ = (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
-      block_valid_mask_ = allocator.aligned_alloc(padded_batch_size * sizeof(bool), 16,
-                                                  "batch_decode_block_valid_mask");
+      block_valid_mask_ = int_allocator.aligned_alloc(
+          padded_batch_size * sizeof(bool), 16, "batch_decode_block_valid_mask");
       bool* block_valid_mask_h_ = (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
       std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
-      size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+      size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)new_indptr_;
       FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
           max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr_h,
           last_page_len_h, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
@@ -392,38 +394,39 @@ class BatchDecodeHandler {
       // do not pad the batch size when not using CUDAGraph
       padded_batch_size_ = batch_size_after_partition_;
       if (split_kv) {
-        AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-        tmp_v_ = allocator.aligned_alloc(
+        AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
+        tmp_v_ = float_allocator.aligned_alloc(
            num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16, "batch_decode_tmp_v");
-        tmp_s_ = allocator.aligned_alloc(num_qo_heads * new_batch_size * sizeof(float), 16,
-                                         "batch_decode_tmp_s");
-        new_indptr_ = allocator.aligned_alloc(
+        tmp_s_ = float_allocator.aligned_alloc(
+            num_qo_heads * new_batch_size * sizeof(float), 16, "batch_decode_tmp_s");
+        AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
+        new_indptr_ = int_allocator.aligned_alloc(
            (batch_size_after_partition_ + 1) * sizeof(IdType), 16, "batch_decode_new_indptr");
         void* new_indptr_h_ = page_locked_buffer_;
-        new_last_page_len_ = allocator.aligned_alloc(
+        new_last_page_len_ = int_allocator.aligned_alloc(
            batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_new_last_page_len");
         void* new_last_page_len_h_ = (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
-        chunk_indptr_ = allocator.aligned_alloc(
+        chunk_indptr_ = int_allocator.aligned_alloc(
            (batch_size_before_partition_ + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr");
         void* chunk_indptr_h_ = (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-        batch_idx_map_ = allocator.aligned_alloc(
+        batch_idx_map_ = int_allocator.aligned_alloc(
            batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_batch_idx_map");
         void* batch_idx_map_h_ = (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-        chunk_start_pos_ = allocator.aligned_alloc(
+        chunk_start_pos_ = int_allocator.aligned_alloc(
            batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_chunk_start_pos");
         void* chunk_start_pos_h_ = (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
         seq_lengths_before_partition_ =
-            allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16,
-                                    "batch_decode_seq_lengths_before_partition");
+            int_allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16,
                                        "batch_decode_seq_lengths_before_partition");
         void* seq_lengths_before_partition_h_ = (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
-        size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
+        size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)new_indptr_;
         FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr_h, last_page_len_h, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, @@ -458,9 +461,9 @@ class BatchDecodeHandler { bool IsForwardStarted() const { return forward_started_; } - void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) { + void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { cudaFreeHost(page_locked_buffer_); - cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); + cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); } uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; } @@ -655,15 +658,17 @@ class BatchPrefillHandler { bool IsForwardStarted() const { return request_indices_ != nullptr; } - void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) { + void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { cudaFreeHost(page_locked_buffer_); - cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); + cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); } template - cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr_h, - IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size) { + cudaError_t BeginForward(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, size_t int_workspace_size_in_bytes, + IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " @@ -683,35 +688,35 @@ class BatchPrefillHandler { if (IsCUDAGraphEnabled()) { padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q); - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - request_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_request_indices"); + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + request_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, + "batch_prefill_request_indices"); void* request_indices_h_ = page_locked_buffer_; - qo_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_qo_tile_indices"); + qo_tile_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, + "batch_prefill_qo_tile_indices"); void* qo_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); - kv_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, - "batch_prefill_kv_tile_indices"); + kv_tile_indices_ = int_allocator.aligned_alloc(sizeof(IdType) * padded_batch_size_, 16, + "batch_prefill_kv_tile_indices"); void* kv_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); - o_indptr_ = allocator.aligned_alloc(sizeof(IdType) * (batch_size + 1), 16, - "batch_prefill_o_indptr"); + o_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * (batch_size + 1), 16, + "batch_prefill_o_indptr"); void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); kv_chunk_size_ptr_ = - allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + int_allocator.aligned_alloc(sizeof(IdType), 
1, "batch_prefill_kv_chunk_size_ptr"); void* kv_chunk_size_ptr_h_ = (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; if (total_num_tiles_q < split_max_batch_size) { // need merge_indptr - merge_indptr_ = allocator.aligned_alloc(sizeof(IdType) * (total_num_rows_ + 1), 16, - "batch_prefill_merge_indptr"); + merge_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * (total_num_rows_ + 1), + 16, "batch_prefill_merge_indptr"); void* merge_indptr_h_ = (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); - block_valid_mask_ = allocator.aligned_alloc(sizeof(bool) * padded_batch_size_, 16, - "batch_prefill_block_valid_mask"); + block_valid_mask_ = int_allocator.aligned_alloc(sizeof(bool) * padded_batch_size_, 16, + "batch_prefill_block_valid_mask"); bool* block_valid_mask_h_ = (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)request_indices_); for (uint32_t i = 0; i < padded_batch_size_; ++i) { @@ -731,15 +736,16 @@ class BatchPrefillHandler { (IdType*)kv_tile_indices_h_); std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; + size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)request_indices_; FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, cudaMemcpyHostToDevice, stream_)) if (total_num_tiles_q < split_max_batch_size) { - tmp_v_ = allocator.aligned_alloc( + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + tmp_v_ = float_allocator.aligned_alloc( num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16, "batch_prefill_tmp_v"); - tmp_s_ = allocator.aligned_alloc( + tmp_s_ = float_allocator.aligned_alloc( num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16, "batch_prefill_tmp_s"); } else { @@ -748,31 +754,31 @@ class BatchPrefillHandler { } } else { padded_batch_size_ = new_batch_size; - AlignedAllocator allocator(buffer, workspace_size_in_bytes); - request_indices_ = allocator.aligned_alloc(sizeof(IdType) * request_indices_vec.size(), - 16, "batch_prefill_request_indices"); + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + request_indices_ = int_allocator.aligned_alloc( + sizeof(IdType) * request_indices_vec.size(), 16, "batch_prefill_request_indices"); void* request_indices_h_ = page_locked_buffer_; - qo_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * qo_tile_indices_vec.size(), - 16, "batch_prefill_qo_tile_indices"); + qo_tile_indices_ = int_allocator.aligned_alloc( + sizeof(IdType) * qo_tile_indices_vec.size(), 16, "batch_prefill_qo_tile_indices"); void* qo_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_); - kv_tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * kv_tile_indices_vec.size(), - 16, "batch_prefill_kv_tile_indices"); + kv_tile_indices_ = int_allocator.aligned_alloc( + sizeof(IdType) * kv_tile_indices_vec.size(), 16, "batch_prefill_kv_tile_indices"); void* kv_tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_); if (split_kv) { // need merge_indptr when split_kv is true - merge_indptr_ = allocator.aligned_alloc(sizeof(IdType) * merge_indptr_vec.size(), 16, - "batch_prefill_merge_indptr"); + merge_indptr_ = 
int_allocator.aligned_alloc(sizeof(IdType) * merge_indptr_vec.size(), + 16, "batch_prefill_merge_indptr"); void* merge_indptr_h_ = (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_); std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_); } - o_indptr_ = allocator.aligned_alloc(sizeof(IdType) * o_indptr_vec.size(), 16, - "batch_prefill_o_indptr"); + o_indptr_ = int_allocator.aligned_alloc(sizeof(IdType) * o_indptr_vec.size(), 16, + "batch_prefill_o_indptr"); void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_); kv_chunk_size_ptr_ = - allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + int_allocator.aligned_alloc(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); void* kv_chunk_size_ptr_h_ = (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_); *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size; @@ -783,16 +789,17 @@ class BatchPrefillHandler { std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), (IdType*)kv_tile_indices_h_); std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_); - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; + size_t num_bytes_to_copy = (char*)int_allocator.ptr - (char*)request_indices_; FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, cudaMemcpyHostToDevice, stream_)) if (split_kv) { - tmp_v_ = allocator.aligned_alloc( + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + tmp_v_ = float_allocator.aligned_alloc( num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16, "batch_prefill_tmp_v"); - tmp_s_ = allocator.aligned_alloc( + tmp_s_ = float_allocator.aligned_alloc( num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16, "batch_prefill_tmp_s"); } else { diff --git a/python/csrc/activation.cu b/python/csrc/activation.cu index e4dcf7c0..4830334b 100644 --- a/python/csrc/activation.cu +++ b/python/csrc/activation.cu @@ -51,8 +51,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); dim3 block(std::min(d / vec_size, 1024U)); - flashinfer::activation::act_and_mul_kernel + flashinfer::activation::act_and_mul_kernel <<>>(static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 94365376..92ce222f 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -21,22 +21,28 @@ using namespace flashinfer; void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, - float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data) { - CHECK_INPUT(workspace_buffer); + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor indptr, + torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, + unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, + unsigned int pos_encoding_mode, float logits_soft_cap, torch::Tensor empty_q_data, + torch::Tensor empty_kv_data) { + CHECK_INPUT(float_workspace_buffer); + 
CHECK_INPUT(int_workspace_buffer); // NOTE(zihao): not necessary to be CUDA tensor CHECK_CONTIGUOUS(indptr); CHECK_CONTIGUOUS(last_page_len); CHECK_DIM(1, indptr); CHECK_DIM(1, last_page_len); - CHECK_DIM(1, workspace_buffer); + CHECK_DIM(1, float_workspace_buffer); + CHECK_DIM(1, int_workspace_buffer); CHECK_EQ(indptr.scalar_type(), torch::kInt32); CHECK_EQ(indptr.scalar_type(), torch::kInt32); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); - auto device = workspace_buffer.device(); + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); handler_->SetCUDAStream(torch_current_stream); indptr = indptr.to(torch::kCPU); @@ -59,8 +65,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), + static_cast(float_workspace_buffer.data_ptr()), + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, static_cast(indptr.data_ptr()), static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, num_kv_heads, page_size); TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ", @@ -81,8 +89,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), + static_cast(float_workspace_buffer.data_ptr()), + float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, static_cast(indptr.data_ptr()), static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, num_kv_heads, page_size); TORCH_CHECK(status == cudaSuccess, @@ -100,8 +110,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int max_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); + unsigned int int_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); } std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 85c4b3ca..86bea032 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -21,29 +21,36 @@ using namespace flashinfer; void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) { - CHECK_INPUT(workspace_buffer); + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + unsigned int page_size, torch::Tensor 
empty_q_data) { + CHECK_INPUT(float_workspace_buffer); + CHECK_INPUT(int_workspace_buffer); // NOTE(Zihao): not necessary to be a CUDA tensor CHECK_CONTIGUOUS(qo_indptr); CHECK_CONTIGUOUS(paged_kv_indptr); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_DIM(1, qo_indptr); CHECK_DIM(1, paged_kv_indptr); - CHECK_DIM(1, workspace_buffer); + CHECK_DIM(1, float_workspace_buffer); + CHECK_DIM(1, int_workspace_buffer); CHECK_EQ(qo_indptr.size(0), batch_size + 1); CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1); qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - auto device = workspace_buffer.device(); - size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); handler_->SetCUDAStream(torch_current_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { cudaError_t status = handler_->BeginForward( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); @@ -56,8 +63,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int max_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); + unsigned int int_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); } std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( @@ -446,28 +453,35 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu } void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( - torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, - unsigned int head_dim, torch::Tensor empty_q_data) { - CHECK_INPUT(workspace_buffer); + torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + torch::Tensor empty_q_data) { + CHECK_INPUT(float_workspace_buffer); + CHECK_INPUT(int_workspace_buffer); // NOTE(Zihao): not necessary to be a CUDA tensor CHECK_CONTIGUOUS(qo_indptr); CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_DIM(1, qo_indptr); CHECK_DIM(1, kv_indptr); - CHECK_DIM(1, workspace_buffer); + CHECK_DIM(1, float_workspace_buffer); + CHECK_DIM(1, int_workspace_buffer); CHECK_EQ(qo_indptr.size(0), batch_size + 1); CHECK_EQ(kv_indptr.size(0), batch_size + 1); qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU)); - size_t 
workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); - auto device = workspace_buffer.device(); + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + auto device = float_workspace_buffer.device(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); handler_->SetCUDAStream(torch_current_stream); DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { cudaError_t status = handler_->BeginForward( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1); @@ -480,8 +494,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( - unsigned int max_workspace_size_in_bytes) { - handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); + unsigned int int_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes); } std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( diff --git a/python/csrc/flashinfer_ops_decode.h b/python/csrc/flashinfer_ops_decode.h index 1f955a7f..0c383c04 100644 --- a/python/csrc/flashinfer_ops_decode.h +++ b/python/csrc/flashinfer_ops_decode.h @@ -28,13 +28,13 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc class BatchDecodeWithPagedKVCachePyTorchWrapper { public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, - torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int pos_encoding_mode, float logits_soft_cap, + void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + unsigned int page_size, unsigned int pos_encoding_mode, float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data); void EndForward(); - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } std::vector Forward(torch::Tensor q, std::optional paged_kv_cache, std::optional paged_k_cache, diff --git a/python/csrc/flashinfer_ops_prefill.h b/python/csrc/flashinfer_ops_prefill.h index 949da9ae..b895f8b1 100644 --- a/python/csrc/flashinfer_ops_prefill.h +++ b/python/csrc/flashinfer_ops_prefill.h @@ -34,13 +34,13 @@ std::vector single_prefill_with_kv_cache_custom_mask( class BatchPrefillWithPagedKVCachePyTorchWrapper { public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor page_kv_indptr, unsigned int batch_size, + void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, 
torch::Tensor page_kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned page_size, torch::Tensor empty_q_data); void EndForward(); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, std::optional paged_k_cache, @@ -69,12 +69,13 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { class BatchPrefillWithRaggedKVCachePyTorchWrapper { public: - void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, - torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads, - unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data); + void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, + torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size, + unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, + torch::Tensor empty_q_data); void EndForward(); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 1a2ad7b4..148b54f6 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -257,22 +257,32 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: manages the lifecycle of these data structures. """ - def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None: + def __init__( + self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" + ) -> None: self._batch_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + float_workspace_buffer, kv_layout ) self._kv_layout = kv_layout - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ - self._batch_decode_wrapper.reset_workspace_buffer(new_workspace_buffer) + self._batch_decode_wrapper.reset_workspace_buffer( + float_workspace_buffer, int_workspace_buffer + ) def begin_forward( self, @@ -503,33 +513,43 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: layers). This wrapper class manages the lifecycle of these data structures. 
""" - def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None: + def __init__( + self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD" + ) -> None: r"""Constructor of :class:`BatchDecodeWithSharedPrefixPagedKVCacheWrapper`. Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace + buffer should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. """ self._batch_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout + float_workspace_buffer, kv_layout ) self._kv_layout = kv_layout - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ - self._batch_prefill_wrapper.reset_workspace_buffer(new_workspace_buffer) + self._batch_prefill_wrapper.reset_workspace_buffer( + float_workspace_buffer, int_workspace_buffer + ) def begin_forward( self, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 123ca8b6..ebb652f8 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -279,7 +279,7 @@ class BatchDecodeWithPagedKVCacheWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, use_tensor_cores: bool = False, @@ -291,9 +291,9 @@ def __init__( Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. 
kv_layout : str @@ -326,7 +326,10 @@ def __init__( """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) if use_cuda_graph: if not torch.is_tensor(paged_kv_indptr_buffer): @@ -363,7 +366,7 @@ def __init__( self._qo_indptr_buf = torch.arange( self._fixed_batch_size + 1, dtype=torch.int32, - device=workspace_buffer.device, + device=float_workspace_buffer.device, ) else: self._use_tensor_cores = False @@ -381,16 +384,26 @@ def use_tensor_cores(self) -> bool: def is_cuda_graph_enabled(self) -> bool: return self._wrapper.is_cuda_graph_enabled() - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) def begin_forward( self, @@ -511,7 +524,8 @@ def begin_forward( batch_size + 1, dtype=torch.int32, device=indptr.device ) self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, self._qo_indptr_buf, indptr, batch_size, @@ -523,7 +537,8 @@ def begin_forward( ) else: self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, indptr, last_page_len, batch_size, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 8c626d2c..8e01ee48 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -541,7 +541,7 @@ class BatchPrefillWithPagedKVCacheWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, qo_indptr_buf: Optional[torch.Tensor] = None, @@ -555,10 +555,10 @@ def __init__( Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved workspace buffer used to store intermediate attention results in + split-k algorithm. The recommended size is 128MB, the device of the workspace buffer + should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. 
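Note: with this change the Python wrappers accept only the float workspace buffer at construction time; the integer workspace (scheduling metadata such as indptr and tile indices) is allocated internally as an 8 MB buffer in the `__init__` hunks. A minimal usage sketch, assuming the wrapper classes are re-exported at the top level of the `flashinfer` package as in the existing Python API, with illustrative buffer sizes:

```python
import torch
import flashinfer

# The user supplies only the float workspace buffer, which holds intermediate
# (split-k) attention results; 128 MB is the size recommended in the docstrings.
float_workspace_buffer = torch.empty(
    128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"
)

# The int workspace buffer is created inside the constructor (8 MB by default),
# so no second argument is needed here.
prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    float_workspace_buffer, kv_layout="NHD"
)
```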
@@ -603,7 +603,10 @@ def __init__( """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, @@ -649,16 +652,26 @@ def __init__( def is_cuda_graph_enabled(self) -> bool: return self._wrapper.is_cuda_graph_enabled() - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) def begin_forward( self, @@ -789,7 +802,8 @@ def begin_forward( ), ) self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, qo_indptr, paged_kv_indptr, batch_size, @@ -1176,7 +1190,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, qo_indptr_buf: Optional[torch.Tensor] = None, @@ -1188,10 +1202,10 @@ def __init__( Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace + buffer should be the same as the device of the input tensors. kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. @@ -1224,7 +1238,10 @@ def __init__( """ _check_kv_layout(kv_layout) self._kv_layout = kv_layout - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) self._wrapper = _prefill.BatchPrefillWithRaggedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, use_cuda_graph, @@ -1257,16 +1274,26 @@ def __init__( def is_cuda_graph_enabled(self) -> bool: return self._wrapper.is_cuda_graph_enabled() - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: r"""Reset the workspace buffer. 
Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) def begin_forward( self, @@ -1376,7 +1403,8 @@ def begin_forward( ), ) self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, qo_indptr, kv_indptr, batch_size, diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py index f80f1231..abfb5bd1 100644 --- a/python/flashinfer/sparse.py +++ b/python/flashinfer/sparse.py @@ -111,18 +111,21 @@ class BlockSparseAttentionWrapper: def __init__( self, - workspace_buffer: torch.Tensor, + float_workspace_buffer: torch.Tensor, ) -> None: r"""Constructs of :class:`BlockSparseAttentionWrapper`. Parameters ---------- - workspace_buffer : torch.Tensor - The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 128MB, the device of the workspace buffer should be the - same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The user reserved float workspace buffer used to store intermediate attention results + in the split-k algorithm. The recommended size is 128MB, the device of the workspace + buffer should be the same as the device of the input tensors. """ - self._workspace_buffer = workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=float_workspace_buffer.device + ) self._wrapper = _prefill.BatchPrefillWithPagedKVCachePyTorchWrapper( TensorLayout["NHD"].value, False, # use_cuda_graph @@ -138,6 +141,27 @@ def __init__( self.M: Optional[int] = None self.N: Optional[int] = None + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer + ) -> None: + r"""Reset the workspace buffer. + + Parameters + ---------- + float_workspace_buffer : torch.Tensor + The new float workspace buffer, the device of the new float workspace buffer should + be the same as the device of the input tensors. + + int_workspace_buffer : torch.Tensor + The new int workspace buffer, the device of the new int workspace buffer should + be the same as the device of the input tensors. 
+ """ + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer + self._wrapper.update_page_locked_buffer_size( + int_workspace_buffer.numel() * int_workspace_buffer.element_size() + ) + def begin_forward( self, indptr: torch.Tensor, @@ -244,7 +268,8 @@ def begin_forward( self.C = C self._wrapper.begin_forward( - self._workspace_buffer, + self._float_workspace_buffer, + self._int_workspace_buffer, self._qo_indptr, self._paged_kv_indptr_buf, num_blocks_row, diff --git a/python/setup.py b/python/setup.py index 9566da6c..a23c6ac3 100644 --- a/python/setup.py +++ b/python/setup.py @@ -313,9 +313,7 @@ def __init__(self, *args, **kwargs) -> None: files_prefill, files_decode = get_instantiation_cu() include_dirs = [ str(root.resolve() / "include"), - str( - root.resolve() / "3rdparty" / "cutlass" / "include" - ), # for group gemm + str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm ] extra_compile_args = { "cxx": [ diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index 024a911e..d81f22a9 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -71,13 +72,16 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { BatchDecodeHandler handler; if (cooperative) { - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); // begin forward BatchDecodeHandlerBeginForward( - &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, pos_encoding_mode); + &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), + float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), + int_workspace_size_in_bytes, kv_indptr_host.data(), kv_last_page_len_host.data(), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); state.exec([&](nvbench::launch&) { cudaError_t status = BatchDecodeWithPagedKVCacheWrapper( @@ -148,12 +152,16 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) { "Read"); state.add_global_memory_writes(vec_bytes(o), "Write"); BatchPrefillHandler handler; - size_t workspace_size_in_bytes = 128 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 128 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), - kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size); state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { cudaError_t status = diff --git a/src/bench_batch_prefill.cu 
b/src/bench_batch_prefill.cu index 81875436..d04e838a 100644 --- a/src/bench_batch_prefill.cu +++ b/src/bench_batch_prefill.cu @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -48,7 +49,10 @@ void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) { thrust::device_vector K(batch_size * kv_len * num_kv_heads * head_dim); thrust::device_vector V(batch_size * kv_len * num_kv_heads * head_dim); thrust::device_vector O(batch_size * qo_len * num_qo_heads * head_dim); - thrust::device_vector workspace(128 * 1024 * 1024); + size_t float_workspace_size_in_bytes = 128 * 1024 * 1024; + thrust::device_vector float_workspace(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_workspace(int_workspace_size_in_bytes); // Provide throughput information: state.add_global_memory_reads( @@ -69,10 +73,11 @@ void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) { BatchPrefillHandler handler; - handler.BeginForward(thrust::raw_pointer_cast(workspace.data()), - workspace.size() * sizeof(uint8_t), qo_indptr_h.data(), - kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, /*page_size=*/1); + handler.BeginForward( + thrust::raw_pointer_cast(float_workspace.data()), float_workspace_size_in_bytes, + thrust::raw_pointer_cast(int_workspace.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, + /*page_size=*/1); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index 0312e06c..56cc3d09 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -107,12 +108,15 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); BatchDecodeHandler cascade_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); BatchDecodeHandlerBeginForward( - &cascade_handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); + &cascade_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), + float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), + int_workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), + batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); @@ -165,12 +169,16 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); BatchDecodeHandler baseline_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 
* 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); BatchDecodeHandlerBeginForward( - &baseline_handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, - num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); + &baseline_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), + float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()), + int_workspace_size_in_bytes, kv_indptr_combined_h.data(), + kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, + page_size, PosEncodingMode::kNone); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); @@ -241,11 +249,15 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_unique_d.data()), thrust::raw_pointer_cast(kv_last_page_len_unique_d.data())); BatchPrefillHandler cascade_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); cascade_handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), - kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, + head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = SinglePrefillWithKVCache( @@ -298,11 +310,15 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) { thrust::raw_pointer_cast(kv_indptr_combined_d.data()), thrust::raw_pointer_cast(kv_last_page_len_combined_d.data())); BatchPrefillHandler baseline_handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); baseline_handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(), - kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, + qo_indptr_h.data(), kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, + head_dim, page_size); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); cudaError_t status = BatchPrefillWithPagedKVCacheWrapper( diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 9fce7ed9..753cb244 100644 --- 
a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -279,8 +279,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( template -cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* buffer, - size_t workspace_size_in_bytes, IdType* indptr_h, +cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* float_buffer, + size_t float_workspace_size_in_bytes, void* int_buffer, + size_t int_workspace_size_in_bytes, IdType* indptr_h, IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, @@ -295,8 +296,8 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { return handler->BeginForwardDispatched( - buffer, workspace_size_in_bytes, indptr_h, last_page_len_h, batch_size, num_qo_heads, - num_kv_heads, page_size); + float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes, + indptr_h, last_page_len_h, batch_size, num_qo_heads, num_kv_heads, page_size); }); }); } diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 64b9b290..1debe0b0 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -98,10 +98,13 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si thrust::raw_pointer_cast(kv_indptr_device.data()), thrust::raw_pointer_cast(kv_last_page_len_device.data())); flashinfer::BatchDecodeHandler handler; - size_t workspace_size_in_bytes = 32 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); BatchDecodeHandlerBeginForward( - &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, + &handler, (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes, + (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes, kv_indptr.data(), kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu index 39602710..26b00607 100644 --- a/src/test_batch_prefill.cu +++ b/src/test_batch_prefill.cu @@ -82,8 +82,10 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data()); BatchPrefillHandler handler; - size_t workspace_size_in_bytes = 128 * 1024 * 1024; - thrust::device_vector buffer(workspace_size_in_bytes); + size_t float_workspace_size_in_bytes = 128 * 1024 * 1024; + thrust::device_vector float_buffer(float_workspace_size_in_bytes); + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + thrust::device_vector int_buffer(int_workspace_size_in_bytes); for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { // create one-hot queries @@ -104,8 +106,10 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n thrust::device_vector o_device(q_len * num_qo_heads * head_dim); handler.BeginForward( - (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(), - kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size); + 
+        (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+        (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+        q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+        page_size);
     for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) {
       auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper output_refs;
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 128 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 128 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
     std::vector q(q_lens[request_idx] * num_qo_heads * head_dim);
@@ -191,8 +197,10 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo
   thrust::device_vector kv_indptr_device(kv_indptr);
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
-      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1);
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      append_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+      /*page_size=*/1);
   auto status = BatchPrefillWithRaggedKVCacheWrapper(
       &handler, thrust::raw_pointer_cast(queries_device.data()),
@@ -315,12 +323,15 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
   thrust::device_vector o_device(o_concat_ref.size());
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(),
-      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+      page_size);
   auto status = BatchPrefillWithPagedKVCacheWrapper(
@@ -438,12 +450,16 @@ void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
   thrust::device_vector o_device(o_concat_ref.size());
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, q_indptr.data(),
-      kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      q_indptr.data(), kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+      page_size);
   auto status = BatchPrefillWithPagedKVCacheWrapper(
@@ -534,12 +550,15 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
   thrust::device_vector o_device(q_lens[0] * num_qo_heads * head_dim);
   BatchPrefillHandler handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
   handler.BeginForward(
-      (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
-      kv_indptr.data(),
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      append_indptr.data(), kv_indptr.data(),
       /*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size);
   auto status = BatchPrefillWithPagedKVCacheWrapper buffer_baseline(workspace_size_in_bytes),
-      buffer_cascade(workspace_size_in_bytes);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
   BatchDecodeHandlerBeginForward(
-      &baseline_handler, (void*)thrust::raw_pointer_cast(buffer_baseline.data()),
-      workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
+      &baseline_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()),
+      float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()),
+      int_workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone);
   BatchDecodeHandlerBeginForward(
-      &cascade_handler, (void*)thrust::raw_pointer_cast(buffer_cascade.data()),
-      workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(),
+      &cascade_handler, (void*)thrust::raw_pointer_cast(float_buffer.data()),
+      float_workspace_size_in_bytes, (void*)thrust::raw_pointer_cast(int_buffer.data()),
+      int_workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(),
       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone);
   // Compute result using baseline implementation
@@ -408,18 +411,21 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
       thrust::raw_pointer_cast(kv_last_page_len_unique_d.data()));
   BatchPrefillHandler baseline_handler, cascade_handler;
-  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
-  thrust::device_vector buffer_baseline(workspace_size_in_bytes),
-      buffer_cascade(workspace_size_in_bytes);
-
-  baseline_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_baseline.data()),
-                                workspace_size_in_bytes, qo_indptr_h.data(),
-                                kv_indptr_combined_h.data(), batch_size, num_qo_heads,
-                                num_kv_heads, head_dim, page_size);
-  cascade_handler.BeginForward((void*)thrust::raw_pointer_cast(buffer_cascade.data()),
-                               workspace_size_in_bytes, qo_indptr_h.data(),
-                               kv_indptr_unique_h.data(), batch_size, num_qo_heads,
-                               num_kv_heads, head_dim, page_size);
+  size_t float_workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector float_buffer(float_workspace_size_in_bytes);
+  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
+  thrust::device_vector int_buffer(int_workspace_size_in_bytes);
+
+  baseline_handler.BeginForward(
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      qo_indptr_h.data(), kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads,
+      head_dim, page_size);
+  cascade_handler.BeginForward(
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads,
+      head_dim, page_size);
   cudaError_t status = BatchPrefillWithPagedKVCacheWrapper(
       &baseline_handler, thrust::raw_pointer_cast(q_d.data()),
diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu
index c8da983b..3ccb4fa5 100644
--- a/src/tvm_wrapper.cu
+++ b/src/tvm_wrapper.cu
@@ -272,11 +272,15 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
 }
 void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
-    int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr,
-    int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
-    int64_t page_size, TVMStreamHandle copy_stream) {
-  CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
-  size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+    int64_t handler_idx, DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer,
+    DLTensor* qo_indptr, DLTensor* kv_indptr, int64_t batch_size, int64_t num_qo_heads,
+    int64_t num_kv_heads, int64_t head_dim, int64_t page_size, TVMStreamHandle copy_stream) {
+  CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor";
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8;
+  CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor";
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8;
   CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
   // NOTE(Zihao): here we presume the input data type is half, in the future we should
@@ -288,7 +292,8 @@ void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_paged_kv_handlers[handler_idx].BeginForward(
-        static_cast(workspace_buffer->data), workspace_size_in_bytes,
+        static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes,
+        static_cast(int_workspace_buffer->data), int_workspace_size_in_bytes,
         static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
         static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
         batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
@@ -416,11 +421,16 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_
 }
 void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
-    int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* page_table_indptr,
-    DLTensor* last_page_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
-    int64_t page_size, int64_t pos_encoding_mode, TVMStreamHandle copy_stream) {
-  CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
-  size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+    int64_t handler_idx, DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer,
+    DLTensor* page_table_indptr, DLTensor* last_page_len, int64_t num_qo_heads,
+    int64_t num_kv_heads, int64_t head_dim, int64_t page_size, int64_t pos_encoding_mode,
+    TVMStreamHandle copy_stream) {
+  CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor";
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8;
+  CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor";
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8;
   CHECK_LT(handler_idx, max_num_handlers)
       << "The handler id must be less than " << max_num_handlers;
   constexpr PageStorage page_storage = PageStorage::kIndices;
@@ -433,8 +443,9 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
   DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, {
     cudaError_t status = BatchDecodeHandlerBeginForward(
-        batch_decode_handlers + handler_idx, static_cast(workspace_buffer->data),
-        workspace_size_in_bytes,
+        batch_decode_handlers + handler_idx, static_cast(float_workspace_buffer->data),
+        float_workspace_size_in_bytes, static_cast(int_workspace_buffer->data),
+        int_workspace_size_in_bytes,
         static_cast(page_table_indptr->data) +
             page_table_indptr->byte_offset / sizeof(dtype_idx),
         static_cast(last_page_len->data) +
@@ -551,10 +562,15 @@ void _FlashInferAttentionPrefillWithRaggedKVCache(
 }
 void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
-    DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr, int64_t batch_size,
-    int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, TVMStreamHandle copy_stream) {
-  CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
-  size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
+    DLTensor* float_workspace_buffer, DLTensor* int_workspace_buffer, DLTensor* qo_indptr,
+    DLTensor* kv_indptr, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
+    int64_t head_dim, TVMStreamHandle copy_stream) {
+  CHECK_EQ(float_workspace_buffer->ndim, 1) << "The float workspace buffer must be a 1-D tensor";
+  size_t float_workspace_size_in_bytes =
+      float_workspace_buffer->shape[0] * float_workspace_buffer->dtype.bits / 8;
+  CHECK_EQ(int_workspace_buffer->ndim, 1) << "The int workspace buffer must be a 1-D tensor";
+  size_t int_workspace_size_in_bytes =
+      int_workspace_buffer->shape[0] * int_workspace_buffer->dtype.bits / 8;
   cudaStream_t original_stream = batch_prefill_ragged_kv_handler.GetCUDAStream();
   batch_prefill_ragged_kv_handler.SetCUDAStream(static_cast(copy_stream));
@@ -564,7 +580,8 @@ void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
   DISPATCH_TVM_CUDA_IDTYPE(qo_indptr->dtype, dtype_idx, {
     cudaError_t status = batch_prefill_ragged_kv_handler.BeginForward(
-        static_cast(workspace_buffer->data), workspace_size_in_bytes,
+        static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes,
+        static_cast(int_workspace_buffer->data), int_workspace_size_in_bytes,
         static_cast(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
         static_cast(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
         batch_size, num_qo_heads, num_kv_heads, head_dim,