
feat: Separate Q and KV dtypes for decode #286

Merged: 10 commits into flashinfer-ai:main from separate_q_kv_dtype_decode on Jun 13, 2024

Conversation

Yard1 (Contributor) commented Jun 5, 2024

Closes #285

Modified unit tests pass. May need some extra validation.

Yard1 (Contributor, Author) commented Jun 5, 2024

@yzh119 Please let me know if this is on the right track! I couldn't see anything directly related to the dtype of the query in the kernels, so my assumption is that this should "just work", but I am not sure whether it might affect e.g. q_vec loading. I am compiling it to test it right now.
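For illustration only (this is not the flashinfer kernel, just a sketch of the idea, and all names here are made up): a decode kernel templated on separate query and KV dtypes can promote the query vector to float once at load time, which is why the KV dtype can stay independent of the query dtype.

#include <cstdint>

template <typename DTypeQ, typename DTypeKV, typename DTypeOut, uint32_t vec_size>
__global__ void DecodeSketchKernel(const DTypeQ* __restrict__ q, const DTypeKV* __restrict__ k,
                                   const DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
                                   uint32_t kv_len) {
  // Each thread loads vec_size query elements and promotes them to float once;
  // k/v tiles are loaded as DTypeKV and promoted the same way inside the kv loop.
  float q_vec[vec_size];
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    q_vec[i] = static_cast<float>(q[threadIdx.x * vec_size + i]);
  }
  // ... attention over k/v in float, write the result back to o as DTypeOut ...
}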

yzh119 (Collaborator) commented Jun 5, 2024

Yes, I do think you are on the right track, thank you!

> but I am not sure whether it might affect e.g. q_vec loading.

I don't think it will.

Yard1 marked this pull request as ready for review June 11, 2024 21:31
Yard1 (Contributor, Author) commented Jun 11, 2024

@yzh119 The modified unit test passes for me; can you review and validate?

Yard1 changed the title [WIP] Separate Q and KV dtypes for decode → Separate Q and KV dtypes for decode Jun 11, 2024
Yard1 changed the title Separate Q and KV dtypes for decode → feat: Separate Q and KV dtypes for decode Jun 11, 2024
yzh119 (Collaborator) left a comment

Hi @Yard1, thanks so much for doing this; it looks good to me in general.

I'd like to request a few more changes, mainly around the BeginForward functions: it seems you assume we are using the same data type for q and kv, and that might affect some resource estimation.

I left some suggested changes; besides those, you also need to separate qtype and kvtype in this function (pass the qtype as an additional empty tensor):

def begin_forward(
    self,
    indptr: torch.Tensor,
    indices: torch.Tensor,
    last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    data_type: Union[str, torch.dtype] = "float16",
):
    r"""Create auxiliary data structures for batch decode for multiple forward calls
    within the same decode step.

    Parameters
    ----------
    indptr : torch.Tensor
        The indptr of the paged kv cache, shape: ``[batch_size + 1]``
    indices : torch.Tensor
        The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]``
    last_page_len : torch.Tensor
        The number of entries in the last page of each request in the paged kv
        cache, shape: ``[batch_size]``
    num_qo_heads : int
        The number of query/output heads
    num_kv_heads : int
        The number of key/value heads
    head_dim : int
        The dimension of the heads
    page_size : int
        The page size of the paged kv cache
    pos_encoding_mode : str
        Whether to apply RoPE on-the-fly inside attention kernels, could be
        ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
    data_type : Union[str, torch.dtype]
        The data type of the paged kv cache

    Note
    ----
    The :meth:`begin_forward` method should be called before any :meth:`forward` or
    :meth:`forward_return_lse` calls, auxiliary data structures will be created
    during this call and cached for multiple forward calls.

    The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
    is not equal to ``num_kv_heads``, the function will use
    `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
    """
    batch_size = len(last_page_len)
    if self.is_cuda_graph_enabled:
        if batch_size != self._fixed_batch_size:
            raise ValueError(
                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
                " mismatches the batch size set during initialization {}".format(
                    batch_size, self._fixed_batch_size
                )
            )
        if len(indices) > len(self._paged_kv_indices_buf):
            raise ValueError(
                "The size of indices should be less than or equal to the allocated buffer"
            )
        self._paged_kv_indptr_buf.copy_(indptr)
        self._paged_kv_indices_buf[: len(indices)] = indices
        self._paged_kv_last_page_len_buf.copy_(last_page_len)
    else:
        self._paged_kv_indptr_buf = indptr
        self._paged_kv_indices_buf = indices
        self._paged_kv_last_page_len_buf = last_page_len

    # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
    empty_data = torch.empty(
        0,
        dtype=(
            getattr(torch, data_type) if isinstance(data_type, str) else data_type
        ),
    )
    self._wrapper.begin_forward(
        self._workspace_buffer,
        indptr,
        last_page_len,
        batch_size,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        PosEncodingMode[pos_encoding_mode].value,
        empty_data,
    )

and update

void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
                  torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
                  unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
                  unsigned int pos_encoding_mode, torch::Tensor empty_data);

void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
    torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len,
    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
    unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
    torch::Tensor empty_data) {
  // NOTE(zihao): not necessary to be CUDA tensor
  CHECK_CONTIGUOUS(indptr);
  CHECK_CONTIGUOUS(last_page_len);
  CHECK_CONTIGUOUS(workspace_buffer);
  CHECK_DIM(1, indptr);
  CHECK_DIM(1, last_page_len);
  CHECK_DIM(1, workspace_buffer);
  CHECK_EQ(indptr.scalar_type(), torch::kInt32);
  CHECK_EQ(last_page_len.scalar_type(), torch::kInt32);
  CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  handler_->SetCUDAStream(torch_current_stream);
  if (is_float8_tensor(empty_data)) {
    DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] {
      return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] {
        return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
          return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
            return DISPATCH_pos_encoding_mode(
                PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
                  cudaError_t status =
                      handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
                                                       KV_LAYOUT, POS_ENCODING_MODE, c_type,
                                                       nv_half, int32_t>(
                          static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
                          static_cast<int32_t*>(indptr.data_ptr()),
                          static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
                          page_size);
                  TORCH_CHECK(status == cudaSuccess,
                              "BatchDecodeWithPagedKVCache failed with error ",
                              cudaGetErrorString(status));
                  return true;
                });
          });
        });
      });
    });
  } else {
    DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
      return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] {
        return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
          return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
            return DISPATCH_pos_encoding_mode(
                PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
                  cudaError_t status =
                      handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
                                                       KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
                                                       int32_t>(
                          static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
                          static_cast<int32_t*>(indptr.data_ptr()),
                          static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
                          page_size);
                  TORCH_CHECK(status == cudaSuccess,
                              "BatchDecodeWithPagedKVCache failed with error ",
                              cudaGetErrorString(status));
                  return true;
                });
          });
        });
      });
    });
  }
}

accordingly.
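For illustration, here is one possible shape of the updated binding. This is only a sketch under stated assumptions: the placeholder tensor names empty_q_data/empty_kv_data and the extended BeginForwardDispatched template ordering (<..., DTypeQ, DTypeKV, DTypeOut, IdType>) are guesses, not the code that was actually merged.

void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
    torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len,
    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
    unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
    torch::Tensor empty_q_data, torch::Tensor empty_kv_data) {
  // ... same contiguity/dim/dtype checks on indptr, last_page_len, workspace_buffer as before ...
  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  handler_->SetCUDAStream(torch_current_stream);
  // Dispatch the q and kv dtypes independently so that resource estimation sees the same
  // template instantiation the forward call will launch. The fp8-KV branch would use
  // DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8 on empty_kv_data in the same way.
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] {
    return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_kv_data.scalar_type(), kv_type, [&] {
      return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] {
        return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
          return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
            return DISPATCH_pos_encoding_mode(
                PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
                  // Assumed extended template: <..., DTypeQ, DTypeKV, DTypeOut, IdType>.
                  cudaError_t status =
                      handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
                                                       KV_LAYOUT, POS_ENCODING_MODE, q_type,
                                                       kv_type, q_type, int32_t>(
                          static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
                          static_cast<int32_t*>(indptr.data_ptr()),
                          static_cast<int32_t*>(last_page_len.data_ptr()), batch_size,
                          num_qo_heads, page_size);
                  TORCH_CHECK(status == cudaSuccess,
                              "BatchDecodeWithPagedKVCache failed with error ",
                              cudaGetErrorString(status));
                  return true;
                });
          });
        });
      });
    });
  });
}

On the Python side, begin_forward would then build two placeholder tensors (one from the query dtype, one from the kv-cache dtype) and pass both to self._wrapper.begin_forward in place of the single empty_data tensor.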

Yard1 (Contributor, Author) commented Jun 12, 2024

@yzh119 Correct, I wanted to avoid having to modify the public API. I don't think the information about the query dtype will be used in resource estimation, but please correct me if that's not the case; happy to make the change then.

yzh119 (Collaborator) commented Jun 12, 2024

Hi @Yard1, I'm a little bit cautious here, because this section of code

auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel<
    /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx,
    bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));

might produce a different num_blocks_per_sm once the q dtype in the kernel differs from the kv dtype.
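For illustration, a minimal sketch of the same occupancy query once the kernel template carries separate q and kv dtype parameters (DTypeQ/DTypeKV are assumed names, not necessarily the ones used in the merged code):

// With a separate DTypeQ parameter, the occupancy estimate is taken from the exact
// instantiation that will later be launched, so register/shared-memory usage (and hence
// num_blocks_per_sm and the kv-split plan) reflects the query dtype as well as the kv dtype.
auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel<
    /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx,
    bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));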

Yard1 (Contributor, Author) commented Jun 12, 2024

Ok sounds good! Let me make the change.

Yard1 requested a review from yzh119 June 13, 2024 23:10
Yard1 (Contributor, Author) commented Jun 13, 2024

@yzh119 Updated, ptal!

yzh119 (Collaborator) left a comment

LGTM, thank you @Yard1!

yzh119 merged commit 5602659 into flashinfer-ai:main Jun 13, 2024
Yard1 deleted the separate_q_kv_dtype_decode branch June 13, 2024 23:51
yzh119 added a commit that referenced this pull request Jun 20, 2024
🤖 I have created a release *beep* *boop*
---


## [0.1.0](v0.0.4...v0.1.0) (2024-06-20)

### Highlights

* Support any GQA group size for tensor-cores kernels.
* Support any page size for tensor-cores kernels.
* Support CUDA-Graph for prefill/decode APIs.
* Add an option to accelerate decode kernels with Tensor Cores.
* Support custom attention mask.
(https://docs.flashinfer.ai/tutorials/kv_layout.html#mask-layout-2d-ragged-tensor)
* Support logits cap in Grok-1 models.
* Fused GPU-sampling kernels: top-p, top-k, speculative verification.
(https://docs.flashinfer.ai/api/python/sampling.html)
* PyTorch wrapper of group-gemm cutlass kernels.
(https://docs.flashinfer.ai/api/python/sampling.html)

### Acknowledgement

We thank [@ibsidorenko](https://github.com/ibsidorenko),
[@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU),
[@Yard1](https://github.com/Yard1),
[@AgrawalAmey](https://github.com/AgrawalAmey),
[@xuzhenqi](https://github.com/xuzhenqi),
[@mgerstgrasser](https://github.com/mgerstgrasser),
[@esmeetu](https://github.com/esmeetu),
[@yz-tang](https://github.com/yz-tang),
[@HSQ79815](https://github.com/HSQ79815),
[@Qubitium](https://github.com/Qubitium),
[@shreygupta2809](https://github.com/shreygupta2809),
[@sighingnow](https://github.com/sighingnow),
[@vinx13](https://github.com/vinx13),
[@tqchen](https://github.com/tqchen),
[@merrymercy](https://github.com/merrymercy),
[@comaniac](https://github.com/comaniac) and many others for their
contributions and helpful discussions for the 0.1.0 release.

### Refactor

* support any GQA group size for tensor-cores kernels
([#301](#301))
([c111ca](c111ca6))
* support any page size for tensor-cores kernels
([#306](#306))
([82fd8c](82fd8c7))


### Features

* add `use_tensor_cores` option to decode kernels to accelerate GQA
([#317](#317))
([3b50dd5](3b50dd5))
* add group gemm operators
([#282](#282))
([e08ba42](e08ba42))
* initial support of distributed operators
([#289](#289))
([03553da](03553da))
* initial support of logits hook
([#298](#298))
([ab1e2ad](ab1e2ad))
* Separate Q and KV dtypes for decode
([#286](#286))
([5602659](5602659))
* support cuda graph for batched multi-query(prefill/append) attention
([#275](#275))
([83ceb67](83ceb67))
* support cuda graph for batched multi-query(prefill/append) attention
([#277](#277))
([24cc583](24cc583))
* support custom attention mask in prefill/append attention kernels
([#266](#266))
([7304282](7304282))
* fused speculative sampling kernels
([#259](#259))
([cea2bb](cea2bb9))
* expose sampling APIs in pytorch
([#238](#238))
([092902](0929023))


### Performance Improvements

* initial cuda graph support
([#256](#256))
([7e9cc7f](7e9cc7f))
* split kv-cache for prefill/append kernels
([#310](#310))
([f0bb0a3](f0bb0a3))
* use packed bit array for attention mask
([#308](#308))
([3d43dc9](3d43dc9))

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
Successfully merging this pull request may close these issues.

[Q&A] Any plans for different dtypes for Q (query) and KV (kv-cache)?