perf: fix the performance issue of append_paged_kv_cache (#588)
The performance of `append_paged_kv_cache` is terrible for small batch sizes, a known issue that we had not fixed for a long time; this PR fixes it. This PR also adds support for non-contiguous append keys/values (which can be sliced from a fused qkv matrix).
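
For example, here is a minimal sketch of the non-contiguous case (hypothetical shapes; slicing k/v out of a fused qkv projection yields views whose row stride is the full fused width):

```python
import torch

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
nnz = 4096  # total number of appended tokens across the batch

# Output of a fused qkv projection: [nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim]
qkv = torch.randn(
    nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim, dtype=torch.float16, device="cuda"
)

# Slicing k/v out of the fused tensor gives non-contiguous views; with this PR they
# can be appended to the paged cache without an extra .contiguous() copy.
q, k, v = qkv.split(
    [num_qo_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1
)
k = k.unflatten(-1, (num_kv_heads, head_dim))  # [nnz, num_kv_heads, head_dim]
v = v.unflatten(-1, (num_kv_heads, head_dim))
assert not k.is_contiguous()
```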

We first call a Triton kernel to convert `append_indptr` to `batch_indices` and `positions` (similar to the [CSR2COO conversion](https://docs.nvidia.com/cuda/cusparse/#cusparse-t-csr2coo) for sparse matrices). After the conversion, the append kernel can use element parallelism instead of batch parallelism.
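
For clarity, here is a plain-PyTorch sketch of the conversion semantics (illustrative only; the actual implementation is the Triton kernel, and the reference helper name below is hypothetical):

```python
import torch


def get_batch_indices_positions_ref(append_indptr, seq_lens, nnz):
    # append_indptr: [batch_size + 1], exclusive prefix sum of per-request append lengths
    # seq_lens:      [batch_size], total KV length of each request *after* the append
    append_lens = (append_indptr[1:] - append_indptr[:-1]).long()
    # CSR -> COO: repeat each request id once per appended element
    batch_indices = torch.repeat_interleave(
        torch.arange(append_lens.numel(), device=append_indptr.device), append_lens
    )
    # local index within each request's appended chunk ...
    local_idx = torch.arange(nnz, device=append_indptr.device) - append_indptr[batch_indices].long()
    # ... shifted to the absolute position in the request's full KV sequence
    positions = seq_lens[batch_indices].long() - append_lens[batch_indices] + local_idx
    return batch_indices.int(), positions.int()
```

With per-element `batch_indices` and `positions` in hand, the CUDA kernel can assign work per appended element (a grid-stride loop over `nnz`) rather than per request, so small batches no longer leave most of the GPU idle.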

It's also worth trying Triton for the second kernel, `AppendPagedKVCacheKernel`; I think the performance should be fine. I'll leave that for future work.
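
For reference, a rough Triton sketch of what that element-parallel append kernel could look like (hypothetical, not part of this PR; it assumes contiguous NHD append keys/values and launches one program per appended element):

```python
import triton
import triton.language as tl


@triton.jit
def append_kv_triton_kernel(
    append_k_ptr, append_v_ptr,        # [nnz, num_heads, head_dim], assumed contiguous here
    k_cache_ptr, v_cache_ptr,          # [num_pages, page_size, num_heads, head_dim] (NHD)
    kv_indices_ptr, kv_indptr_ptr,     # page table: page ids + per-request indptr into them
    batch_indices_ptr, positions_ptr,  # per-element outputs of the CSR2COO step
    page_size, num_heads, head_dim,
    BLOCK: tl.constexpr,               # power of two >= num_heads * head_dim
):
    i = tl.program_id(0)               # one program per appended element
    b = tl.load(batch_indices_ptr + i)
    pos = tl.load(positions_ptr + i)
    # Locate the page and the slot inside the page for this element.
    page_iter = tl.load(kv_indptr_ptr + b) + pos // page_size
    page_id = tl.load(kv_indices_ptr + page_iter)
    entry = pos % page_size

    offs = tl.arange(0, BLOCK)
    mask = offs < num_heads * head_dim
    src_base = i * num_heads * head_dim
    dst_base = (page_id * page_size + entry) * num_heads * head_dim

    tl.store(k_cache_ptr + dst_base + offs,
             tl.load(append_k_ptr + src_base + offs, mask=mask), mask=mask)
    tl.store(v_cache_ptr + dst_base + offs,
             tl.load(append_v_ptr + src_base + offs, mask=mask), mask=mask)
```

A launch would look roughly like `append_kv_triton_kernel[(nnz,)](..., BLOCK=triton.next_power_of_2(num_heads * head_dim))`.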

Some todo items:
1. Add `torch.compile` support.

After this PR (reference numbers can be found in #583):
```bash
model: l1b      seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.094ms throughput:    5.563GB/s
model: l1b      seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.014ms all_layers:   0.216ms throughput: 1514.280GB/s
model: l1b      seqlens: [5000]                                   single_layer: 0.014ms all_layers:   0.216ms throughput: 1517.017GB/s
model: l1b      seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.014ms all_layers:   0.217ms throughput: 1510.863GB/s
---
model: l3b      seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.165ms throughput:   11.123GB/s
model: l3b      seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.021ms all_layers:   0.580ms throughput: 1975.732GB/s
model: l3b      seqlens: [5000]                                   single_layer: 0.021ms all_layers:   0.586ms throughput: 1958.078GB/s
model: l3b      seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.021ms all_layers:   0.581ms throughput: 1973.174GB/s
---
model: l8b      seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.185ms throughput:   11.321GB/s
model: l8b      seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.021ms all_layers:   0.661ms throughput: 1982.815GB/s
model: l8b      seqlens: [5000]                                   single_layer: 0.021ms all_layers:   0.662ms throughput: 1980.227GB/s
model: l8b      seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.021ms all_layers:   0.667ms throughput: 1964.861GB/s
---
model: l70b-tp8 seqlens: [1, 1, 1, 1, 1, 1, 1, 1]                 single_layer: 0.006ms all_layers:   0.457ms throughput:    1.434GB/s
model: l70b-tp8 seqlens: [4993, 1, 1, 1, 1, 1, 1, 1]              single_layer: 0.009ms all_layers:   0.710ms throughput:  576.866GB/s
model: l70b-tp8 seqlens: [5000]                                   single_layer: 0.009ms all_layers:   0.685ms throughput:  598.366GB/s
model: l70b-tp8 seqlens: [625, 625, 625, 625, 625, 625, 625, 625] single_layer: 0.009ms all_layers:   0.690ms throughput:  593.453GB/s
```

cc @abcdabcd987
yzh119 authored Nov 6, 2024
1 parent 1328693 commit e15f7c9
Showing 9 changed files with 285 additions and 93 deletions.
9 changes: 8 additions & 1 deletion benchmarks/bench_append_paged_kv_cache.py
@@ -99,12 +99,19 @@ def main():
dtype=torch.int32,
)

batch_indices, positions = flashinfer.get_batch_indices_positions(
x_indptr,
flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len),
k.shape[0],
)

@torch.cuda.nvtx.range(f"model={model_name}, seqlens={seqlens}")
def fn():
flashinfer.append_paged_kv_cache(
k,
v,
x_indptr,
batch_indices,
positions,
layer_buf,
kv_indices,
kv_indptr,
72 changes: 42 additions & 30 deletions include/flashinfer/page.cuh
@@ -249,38 +249,34 @@ __global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t<DType, IdType> paged_k
* \param paged_kv The paged key-value cache
* \param key The key to be appended
* \param value The value to be appended
* \param append_indptr The indptr array of the appended ragged tensor
* \param batch_indices The batch indices of elements to be appended
* \param positions The positions of elements to be appended
*/
template <uint32_t head_dim, uint32_t vec_size, typename DType, typename IdType>
__global__ void AppendPagedKVCachePrefillKernel(paged_kv_t<DType, IdType> paged_kv,
DType* __restrict__ key, DType* __restrict__ value,
IdType* __restrict__ append_indptr) {
__global__ void AppendPagedKVCacheKernel(paged_kv_t<DType, IdType> paged_kv,
DType* __restrict__ append_key,
DType* __restrict__ append_value,
IdType* __restrict__ batch_indices,
IdType* __restrict__ positions, uint32_t nnz,
size_t append_k_stride_n, size_t append_k_stride_h,
size_t append_v_stride_n, size_t append_v_stride_h) {
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t num_heads = paged_kv.num_heads;
uint32_t batch_idx = blockIdx.x;
uint32_t head_idx = ty;

uint32_t seq_len =
(paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size +
paged_kv.last_page_len[batch_idx];
uint32_t append_seq_len = append_indptr[batch_idx + 1] - append_indptr[batch_idx];
uint32_t append_start = seq_len - append_seq_len;

#pragma unroll 2
for (uint32_t j = 0; j < append_seq_len; ++j) {
uint32_t page_seq_idx = j + append_start;
uint32_t page_iter = paged_kv.indptr[batch_idx] + page_seq_idx / paged_kv.page_size;
uint32_t entry_idx = page_seq_idx % paged_kv.page_size;

uint32_t cta_id = blockIdx.x;
uint32_t num_ctas = gridDim.x;

#pragma unroll 4
for (uint32_t i = cta_id; i < nnz; i += num_ctas) {
uint32_t page_iter, entry_idx;
paged_kv.page_size.divmod(paged_kv.indptr[batch_indices[i]] * paged_kv.page_size + positions[i],
page_iter, entry_idx);
DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
vec_t<DType, vec_size>::memcpy(
k_ptr,
key + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);

k_ptr, append_key + i * append_k_stride_n + head_idx * append_k_stride_h + tx * vec_size);
vec_t<DType, vec_size>::memcpy(
v_ptr,
value + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);
v_ptr, append_value + i * append_v_stride_n + head_idx * append_v_stride_h + tx * vec_size);
}
}

@@ -327,20 +323,36 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t<DType, IdType> paged_kv, DType*
* \return status Indicates whether CUDA calls are successful
*/
template <typename DType, typename IdType>
cudaError_t AppendPagedKVCache(paged_kv_t<DType, IdType> paged_kv, DType* key, DType* value,
IdType* append_indptr, cudaStream_t stream = nullptr) {
cudaError_t AppendPagedKVCache(paged_kv_t<DType, IdType> paged_kv, DType* append_key,
DType* append_value, IdType* batch_indices, IdType* positions,
uint32_t nnz, size_t append_k_stride_n, size_t append_k_stride_h,
size_t append_v_stride_n, size_t append_v_stride_h,
cudaStream_t stream = nullptr) {
uint32_t head_dim = paged_kv.head_dim;
uint32_t batch_size = paged_kv.batch_size;
uint32_t num_heads = paged_kv.num_heads;
int dev_id = 0;
int num_sms = 0;
int num_blocks_per_sm = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
uint32_t bdx = HEAD_DIM / vec_size;
uint32_t bdy = num_heads;
// NOTE(Zihao): could be slow for small batch size, will optimize later
dim3 nblks(batch_size);
uint32_t num_threads = bdx * bdy;
uint32_t smem_size = 0;
auto kernel = AppendPagedKVCacheKernel<HEAD_DIM, vec_size, DType, IdType>;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms));
dim3 nblks(num_blocks_per_sm * num_sms);
dim3 nthrs(bdx, bdy);
auto kernel = AppendPagedKVCachePrefillKernel<HEAD_DIM, vec_size, DType, IdType>;
void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, (void*)&append_indptr};

void* args[] = {(void*)&paged_kv, (void*)&append_key, (void*)&append_value,
(void*)&batch_indices, (void*)&positions, (void*)&nnz,
(void*)&append_k_stride_n, (void*)&append_k_stride_h, (void*)&append_v_stride_n,
(void*)&append_v_stride_h};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
return cudaSuccess;
8 changes: 4 additions & 4 deletions python/csrc/flashinfer_page_ops.cu
@@ -16,10 +16,10 @@
#include <torch/extension.h>

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
torch::Tensor append_indptr, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor kv_indices,
torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
unsigned int layout);
torch::Tensor batch_indices, torch::Tensor positions,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
torch::Tensor kv_indices, torch::Tensor kv_indptr,
torch::Tensor kv_last_page_len, unsigned int layout);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator");
46 changes: 30 additions & 16 deletions python/csrc/page.cu
@@ -20,13 +20,14 @@
using namespace flashinfer;

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
torch::Tensor append_indptr, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor kv_indices,
torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
unsigned int layout) {
CHECK_INPUT(append_key);
CHECK_INPUT(append_value);
CHECK_INPUT(append_indptr);
torch::Tensor batch_indices, torch::Tensor positions,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
torch::Tensor kv_indices, torch::Tensor kv_indptr,
torch::Tensor kv_last_page_len, unsigned int layout) {
CHECK_LAST_DIM_CONTIGUOUS(append_key);
CHECK_LAST_DIM_CONTIGUOUS(append_value);
CHECK_INPUT(batch_indices);
CHECK_INPUT(positions);
// NOTE(Zihao): doesn't have to be contiguous
CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_k_cache);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_v_cache);
@@ -35,20 +36,24 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
CHECK_INPUT(kv_last_page_len);
CHECK_DIM(3, append_key);
CHECK_DIM(3, append_value);
CHECK_DIM(1, append_indptr);
CHECK_DIM(1, batch_indices);
CHECK_DIM(1, positions);
CHECK_DIM(4, paged_k_cache);
CHECK_DIM(4, paged_v_cache);
CHECK_DIM(1, kv_indices);
CHECK_DIM(1, kv_indptr);
CHECK_DIM(1, kv_last_page_len);
unsigned int nnz = append_key.size(0);
unsigned int batch_size = kv_last_page_len.size(0);
CHECK_EQ(append_indptr.size(0), batch_size + 1);
CHECK_EQ(kv_indptr.size(0), batch_size + 1);
CHECK_EQ(append_indptr.scalar_type(), torch::kInt32);
CHECK_EQ(batch_indices.size(0), nnz);
CHECK_EQ(positions.size(0), nnz);
CHECK_EQ(batch_indices.scalar_type(), torch::kInt32);
CHECK_EQ(positions.scalar_type(), torch::kInt32);
CHECK_EQ(kv_indptr.scalar_type(), torch::kInt32);
CHECK_EQ(kv_indices.scalar_type(), torch::kInt32);
CHECK_EQ(kv_last_page_len.scalar_type(), torch::kInt32);
auto device = append_indptr.device();
auto device = append_key.device();
CHECK_EQ(append_key.device(), device);
CHECK_EQ(append_value.device(), device);
CHECK_EQ(paged_k_cache.device(), device);
@@ -76,10 +81,17 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
kv_cache_strides = k_strides.data();

auto append_k_strides = append_key.strides();
auto append_k_stride_n = append_k_strides[0];
auto append_k_stride_h = append_k_strides[1];
auto append_v_strides = append_value.strides();
auto append_v_stride_n = append_v_strides[0];
auto append_v_stride_h = append_v_strides[1];

CHECK_EQ(append_key.size(1), num_heads);
CHECK_EQ(append_key.size(2), head_dim);
CHECK_EQ(append_value.size(1), num_heads);
CHECK_EQ(append_key.size(2), head_dim);
CHECK_EQ(append_value.size(2), head_dim);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

@@ -92,10 +104,12 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
static_cast<c_type*>(paged_v_cache.data_ptr()), kv_cache_strides,
static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
static_cast<int32_t*>(kv_last_page_len.data_ptr()));
cudaError_t status =
AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
static_cast<c_type*>(append_value.data_ptr()),
static_cast<int32_t*>(append_indptr.data_ptr()), torch_current_stream);
cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
static_cast<c_type*>(append_value.data_ptr()),
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()), nnz,
append_k_stride_n, append_k_stride_h, append_v_stride_n,
append_v_stride_h, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"AppendPagedKVCache failed with error: ", cudaGetErrorString(status));
return true;
8 changes: 4 additions & 4 deletions python/csrc_aot/flashinfer_ops.cu
@@ -80,10 +80,10 @@ void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torc
//========== page ==========

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
torch::Tensor append_indptr, torch::Tensor paged_k_cache,
torch::Tensor paged_v_cache, torch::Tensor kv_indices,
torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
unsigned int layout);
torch::Tensor batch_indices, torch::Tensor positions,
torch::Tensor paged_k_cache, torch::Tensor paged_v_cache,
torch::Tensor kv_indices, torch::Tensor kv_indptr,
torch::Tensor kv_last_page_len, unsigned int layout);

//========== prefill ==========

2 changes: 2 additions & 0 deletions python/flashinfer/__init__.py
@@ -46,6 +46,8 @@
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
from .page import get_seq_lens as get_seq_lens
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper as BatchPrefillWithPagedKVCacheWrapper,
)