
feat: decouple float and int workspace buffer (#442)
Before this PR, flashinfer coupled the float and int buffers in a single
workspace buffer, so different wrappers could not share the same buffer.

This PR decouples the float and int workspace buffers. The float workspace
buffer (large) can be shared by multiple wrappers, while the int workspace
buffer (small) is unique to each wrapper. This saves GPU memory when
multiple wrappers are created (decode, paged prefill, ragged prefill) or
when cascade inference is used.
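
For illustration only, a minimal usage sketch of the decoupled layout from the Python API, based on the constructor changes in python/flashinfer/cascade.py below; the buffer size, dtype, and the assumption that each wrapper manages its own small int workspace buffer are not prescribed by this commit:

    import torch
    import flashinfer

    # One large float workspace buffer (size and dtype are assumptions,
    # e.g. 128 MB of uint8) shared by several wrappers; each wrapper keeps
    # its own small int workspace buffer, so only the small per-wrapper
    # allocation is duplicated.
    float_workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0"
    )

    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        float_workspace_buffer, kv_layout="NHD"
    )
    prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        float_workspace_buffer, kv_layout="NHD"
    )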
yzh119 authored Aug 13, 2024
1 parent 3fff008 commit a7ee566
Showing 19 changed files with 467 additions and 275 deletions.
149 changes: 78 additions & 71 deletions include/flashinfer/attention/handler.cuh

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions python/csrc/activation.cu
@@ -51,8 +51,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
flashinfer::activation::act_and_mul_kernel<c_type,
flashinfer::activation::gelu_tanh_kernel>
flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_tanh_kernel>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
static_cast<c_type*>(input.data_ptr()), d);

38 changes: 24 additions & 14 deletions python/csrc/batch_decode.cu
@@ -21,22 +21,28 @@
using namespace flashinfer;

void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
float logits_soft_cap, torch::Tensor empty_q_data, torch::Tensor empty_kv_data) {
CHECK_INPUT(workspace_buffer);
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
unsigned int pos_encoding_mode, float logits_soft_cap, torch::Tensor empty_q_data,
torch::Tensor empty_kv_data) {
CHECK_INPUT(float_workspace_buffer);
CHECK_INPUT(int_workspace_buffer);
// NOTE(zihao): not necessary to be CUDA tensor
CHECK_CONTIGUOUS(indptr);
CHECK_CONTIGUOUS(last_page_len);
CHECK_DIM(1, indptr);
CHECK_DIM(1, last_page_len);
CHECK_DIM(1, workspace_buffer);
CHECK_DIM(1, float_workspace_buffer);
CHECK_DIM(1, int_workspace_buffer);
CHECK_EQ(indptr.scalar_type(), torch::kInt32);
CHECK_EQ(indptr.scalar_type(), torch::kInt32);
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
auto device = workspace_buffer.device();
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
handler_->SetCUDAStream(torch_current_stream);
indptr = indptr.to(torch::kCPU);
@@ -59,8 +65,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
LOGITS_POST_HOOK, POS_ENCODING_MODE, qkv_type,
qkv_type, qkv_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, page_size);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
@@ -81,8 +89,10 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
LOGITS_POST_HOOK, POS_ENCODING_MODE, q_type,
kv_type, q_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
@@ -100,8 +110,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int max_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
54 changes: 34 additions & 20 deletions python/csrc/batch_prefill.cu
@@ -21,29 +21,36 @@
using namespace flashinfer;

void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
CHECK_INPUT(workspace_buffer);
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
unsigned int page_size, torch::Tensor empty_q_data) {
CHECK_INPUT(float_workspace_buffer);
CHECK_INPUT(int_workspace_buffer);
// NOTE(Zihao): not necessary to be a CUDA tensor
CHECK_CONTIGUOUS(qo_indptr);
CHECK_CONTIGUOUS(paged_kv_indptr);
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
CHECK_DIM(1, qo_indptr);
CHECK_DIM(1, paged_kv_indptr);
CHECK_DIM(1, workspace_buffer);
CHECK_DIM(1, float_workspace_buffer);
CHECK_DIM(1, int_workspace_buffer);
CHECK_EQ(qo_indptr.size(0), batch_size + 1);
CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
auto device = workspace_buffer.device();
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
handler_->SetCUDAStream(torch_current_stream);

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()),
static_cast<int32_t*>(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads,
head_dim, page_size);
@@ -56,8 +63,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int max_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
@@ -446,28 +453,35 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
}

void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim, torch::Tensor empty_q_data) {
CHECK_INPUT(workspace_buffer);
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
torch::Tensor empty_q_data) {
CHECK_INPUT(float_workspace_buffer);
CHECK_INPUT(int_workspace_buffer);
// NOTE(Zihao): not necessary to be a CUDA tensor
CHECK_CONTIGUOUS(qo_indptr);
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
CHECK_DIM(1, qo_indptr);
CHECK_DIM(1, kv_indptr);
CHECK_DIM(1, workspace_buffer);
CHECK_DIM(1, float_workspace_buffer);
CHECK_DIM(1, int_workspace_buffer);
CHECK_EQ(qo_indptr.size(0), batch_size + 1);
CHECK_EQ(kv_indptr.size(0), batch_size + 1);
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
auto device = workspace_buffer.device();
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
auto device = float_workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
handler_->SetCUDAStream(torch_current_stream);

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim,
/*page_size=*/1);
@@ -480,8 +494,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int max_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes);
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
10 changes: 5 additions & 5 deletions python/csrc/flashinfer_ops_decode.h
@@ -28,13 +28,13 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc

class BatchDecodeWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
unsigned int pos_encoding_mode, float logits_soft_cap,
void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
unsigned int page_size, unsigned int pos_encoding_mode, float logits_soft_cap,
torch::Tensor empty_q_data, torch::Tensor empty_kv_data);
void EndForward();
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes);
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
std::vector<torch::Tensor> Forward(torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
15 changes: 8 additions & 7 deletions python/csrc/flashinfer_ops_prefill.h
@@ -34,13 +34,13 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(

class BatchPrefillWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor page_kv_indptr, unsigned int batch_size,
void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor page_kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
unsigned page_size, torch::Tensor empty_q_data);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
@@ -69,12 +69,13 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {

class BatchPrefillWithRaggedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
torch::Tensor kv_indptr, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, torch::Tensor empty_q_data);
void BeginForward(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
torch::Tensor empty_q_data);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
void UpdatePageLockedBufferSize(uint32_t int_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
torch::Tensor v, torch::Tensor kv_indptr, bool causal,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
52 changes: 36 additions & 16 deletions python/flashinfer/cascade.py
@@ -257,22 +257,32 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
manages the lifecycle of these data structures.
"""

def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None:
def __init__(
self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
) -> None:
self._batch_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
float_workspace_buffer, kv_layout
)
self._kv_layout = kv_layout

def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None:
def reset_workspace_buffer(
self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
) -> None:
r"""Reset the workspace buffer.
Parameters
----------
new_workspace_buffer : torch.Tensor
The new workspace buffer, the device of the new workspace buffer should
float_workspace_buffer : torch.Tensor
The new float workspace buffer, the device of the new float workspace buffer should
be the same as the device of the input tensors.
int_workspace_buffer : torch.Tensor
The new int workspace buffer, the device of the new int workspace buffer should
be the same as the device of the input tensors.
"""
self._batch_decode_wrapper.reset_workspace_buffer(new_workspace_buffer)
self._batch_decode_wrapper.reset_workspace_buffer(
float_workspace_buffer, int_workspace_buffer
)

def begin_forward(
self,
@@ -503,33 +513,43 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
layers). This wrapper class manages the lifecycle of these data structures.
"""

def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD") -> None:
def __init__(
self, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
) -> None:
r"""Constructor of :class:`BatchDecodeWithSharedPrefixPagedKVCacheWrapper`.
Parameters
----------
workspace_buffer : torch.Tensor
The user reserved workspace buffer used to store auxiliary data structures,
recommended size is 128MB, the device of the workspace buffer should be the
same as the device of the input tensors.
float_workspace_buffer : torch.Tensor
The user reserved float workspace buffer used to store intermediate attention results
in the split-k algorithm. The recommended size is 128MB, the device of the workspace
buffer should be the same as the device of the input tensors.
kv_layout : str
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
"""
self._batch_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
float_workspace_buffer, kv_layout
)
self._kv_layout = kv_layout

def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None:
def reset_workspace_buffer(
self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
) -> None:
r"""Reset the workspace buffer.
Parameters
----------
new_workspace_buffer : torch.Tensor
The new workspace buffer, the device of the new workspace buffer should
float_workspace_buffer : torch.Tensor
The new float workspace buffer, the device of the new float workspace buffer should
be the same as the device of the input tensors.
int_workspace_buffer : torch.Tensor
The new int workspace buffer, the device of the new int workspace buffer should
be the same as the device of the input tensors.
"""
self._batch_prefill_wrapper.reset_workspace_buffer(new_workspace_buffer)
self._batch_prefill_wrapper.reset_workspace_buffer(
float_workspace_buffer, int_workspace_buffer
)

def begin_forward(
self,
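
As a hedged sketch of the new two-argument reset_workspace_buffer call shown above (the buffer sizes and dtype here are assumptions, not values prescribed by this commit):

    # The float buffer is the large, shareable one; the int buffer is the
    # small per-wrapper one.
    new_float_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    new_int_buffer = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    wrapper.reset_workspace_buffer(new_float_buffer, new_int_buffer)

Here `wrapper` is a placeholder for any of the wrapper objects above, e.g. a BatchDecodeWithSharedPrefixPagedKVCacheWrapper instance.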