feat: Separate Q and KV dtypes for decode #286
Conversation
@yzh119 Please let me know if this is on the right track! I couldn't see anything directly related to the dtype of the query in the kernels, so my assumption is that this should "just work", but I don't know whether it could affect anything else.
Yes I do think you are on the right track, thank you!
I don't think so.
@yzh119 The modified unit test passes for me; can you review and validate?
Hi @Yard1, thanks so much for doing this; it looks good to me in general.
I'd like to ask for some other changes, mainly around the BeginForward functions, because it seems you assume we are using the same data type for q and kv, which might affect some resource estimation.
I left some suggested changes; besides those, you also need to separate the q dtype and kv dtype in this function (pass the q dtype as an empty tensor as well):
flashinfer/python/flashinfer/decode.py
Lines 532 to 620 in 1250b68
    def begin_forward(
        self,
        indptr: torch.Tensor,
        indices: torch.Tensor,
        last_page_len: torch.Tensor,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        page_size: int,
        pos_encoding_mode: str = "NONE",
        data_type: Union[str, torch.dtype] = "float16",
    ):
        r"""Create auxiliary data structures for batch decode for multiple forward calls
        within the same decode step.

        Parameters
        ----------
        indptr : torch.Tensor
            The indptr of the paged kv cache, shape: ``[batch_size + 1]``
        indices : torch.Tensor
            The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]``
        last_page_len : torch.Tensor
            The number of entries in the last page of each request in the paged kv
            cache, shape: ``[batch_size]``
        num_qo_heads : int
            The number of query/output heads
        num_kv_heads : int
            The number of key/value heads
        head_dim : int
            The dimension of the heads
        page_size : int
            The page size of the paged kv cache
        pos_encoding_mode : str
            Whether to apply RoPE on-the-fly inside attention kernels, could be
            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
        data_type : Union[str, torch.dtype]
            The data type of the paged kv cache

        Note
        ----
        The :meth:`begin_forward` method should be called before any :meth:`forward` or
        :meth:`forward_return_lse` calls, auxiliary data structures will be created
        during this call and cached for multiple forward calls.

        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
        is not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
        """
        batch_size = len(last_page_len)
        if self.is_cuda_graph_enabled:
            if batch_size != self._fixed_batch_size:
                raise ValueError(
                    "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
                    " mismatches the batch size set during initialization {}".format(
                        batch_size, self._fixed_batch_size
                    )
                )
            if len(indices) > len(self._paged_kv_indices_buf):
                raise ValueError(
                    "The size of indices should be less than or equal to the allocated buffer"
                )
            self._paged_kv_indptr_buf.copy_(indptr)
            self._paged_kv_indices_buf[: len(indices)] = indices
            self._paged_kv_last_page_len_buf.copy_(last_page_len)
        else:
            self._paged_kv_indptr_buf = indptr
            self._paged_kv_indices_buf = indices
            self._paged_kv_last_page_len_buf = last_page_len
        # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info
        empty_data = torch.empty(
            0,
            dtype=(
                getattr(torch, data_type) if isinstance(data_type, str) else data_type
            ),
        )
        self._wrapper.begin_forward(
            self._workspace_buffer,
            indptr,
            last_page_len,
            batch_size,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            PosEncodingMode[pos_encoding_mode].value,
            empty_data,
        )
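For illustration only, here is a minimal sketch of how the Python side could carry both dtypes as placeholder tensors; the helper and the `q_data_type`/`kv_data_type` parameter names are hypothetical, not the merged API:

```python
# Hypothetical sketch (not the merged API): split the single ``data_type``
# argument into separate query and KV-cache dtypes and carry both across the
# Python/C++ boundary as zero-element placeholder tensors.
from typing import Tuple, Union

import torch


def _as_torch_dtype(data_type: Union[str, torch.dtype]) -> torch.dtype:
    # Accept either "float16"-style strings or torch.float16-style dtypes,
    # mirroring how the existing begin_forward handles ``data_type``.
    return getattr(torch, data_type) if isinstance(data_type, str) else data_type


def make_dtype_placeholders(
    q_data_type: Union[str, torch.dtype] = "float16",
    kv_data_type: Union[str, torch.dtype] = "float16",
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Zero-element tensors carry only dtype information; no data is allocated.
    empty_q_data = torch.empty(0, dtype=_as_torch_dtype(q_data_type))
    empty_kv_data = torch.empty(0, dtype=_as_torch_dtype(kv_data_type))
    return empty_q_data, empty_kv_data


# begin_forward would then forward both placeholders to the C++ wrapper, e.g.
# self._wrapper.begin_forward(..., empty_q_data, empty_kv_data)
```

The zero-element tensors are cheap: they exist purely so the C++ side can read `.scalar_type()` separately for q and kv.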
and update
flashinfer/python/csrc/flashinfer_ops.h
Lines 77 to 80 in 1250b68
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr,
                  torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
                  unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
                  unsigned int pos_encoding_mode, torch::Tensor empty_data);
flashinfer/python/csrc/batch_decode.cu
Lines 120 to 188 in 1250b68
void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
    torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len,
    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
    unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode,
    torch::Tensor empty_data) {
  // NOTE(zihao): not necessary to be CUDA tensor
  CHECK_CONTIGUOUS(indptr);
  CHECK_CONTIGUOUS(last_page_len);
  CHECK_CONTIGUOUS(workspace_buffer);
  CHECK_DIM(1, indptr);
  CHECK_DIM(1, last_page_len);
  CHECK_DIM(1, workspace_buffer);
  CHECK_EQ(indptr.scalar_type(), torch::kInt32);
  CHECK_EQ(indptr.scalar_type(), torch::kInt32);
  CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
  size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  handler_->SetCUDAStream(torch_current_stream);
  if (is_float8_tensor(empty_data)) {
    DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] {
      return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] {
        return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
          return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
            return DISPATCH_pos_encoding_mode(
                PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
                  cudaError_t status =
                      handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
                                                       KV_LAYOUT, POS_ENCODING_MODE, c_type,
                                                       nv_half, int32_t>(
                          static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
                          static_cast<int32_t*>(indptr.data_ptr()),
                          static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
                          page_size);
                  TORCH_CHECK(status == cudaSuccess,
                              "BatchDecodeWithPagedKVCache failed with error ",
                              cudaGetErrorString(status));
                  return true;
                });
          });
        });
      });
    });
  } else {
    DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] {
      return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] {
        return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
          return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
            return DISPATCH_pos_encoding_mode(
                PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
                  cudaError_t status =
                      handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
                                                       KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
                                                       int32_t>(
                          static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
                          static_cast<int32_t*>(indptr.data_ptr()),
                          static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
                          page_size);
                  TORCH_CHECK(status == cudaSuccess,
                              "BatchDecodeWithPagedKVCache failed with error ",
                              cudaGetErrorString(status));
                  return true;
                });
          });
        });
      });
    });
  }
}
accordingly.
@yzh119 Correct, I wanted to avoid having to modify the public API. I don't think the information about the query dtype will be used in resource estimation, but please correct me if that's not the case; happy to make the change then.
Hi @Yard1, I'm a little bit conservative here because this section of code (flashinfer/include/flashinfer/attention/handler.cuh, lines 121 to 130 in 1250b68) might produce different resource estimates depending on the data type.
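To illustrate why the dtype used at plan time can matter, here is a purely illustrative sketch; the shared-memory budget and tile formula below are made up, not the actual logic in handler.cuh, but they show how a scheduler's tiling decisions can scale with the KV element size:

```python
# Purely illustrative, NOT flashinfer's actual estimator: the number of KV
# tiles a thread block can stage in shared memory depends on the KV element
# size, so an estimate computed from the query dtype instead of the KV dtype
# could produce a different plan (and workspace size).
SMEM_BYTES_PER_BLOCK = 48 * 1024  # assumed per-block shared-memory budget


def kv_tiles_per_block(head_dim: int, page_size: int, kv_elem_bytes: int) -> int:
    bytes_per_tile = 2 * page_size * head_dim * kv_elem_bytes  # one K tile + one V tile
    return max(1, SMEM_BYTES_PER_BLOCK // bytes_per_tile)


print(kv_tiles_per_block(head_dim=128, page_size=16, kv_elem_bytes=2))  # fp16 KV -> 6
print(kv_tiles_per_block(head_dim=128, page_size=16, kv_elem_bytes=1))  # fp8 KV  -> 12
```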
Ok, sounds good! Let me make the change.
@yzh119 Updated, PTAL!
LGTM, thank you @Yard1!
🤖 I have created a release *beep* *boop*

---

## [0.1.0](v0.0.4...v0.1.0) (2024-06-20)

### Highlights

* Support any GQA group size for tensor-cores kernels.
* Support any page size for tensor-cores kernels.
* Support CUDA-Graph for prefill/decode APIs.
* Add an option to accelerate decode kernels with Tensor Cores.
* Support custom attention mask. (https://docs.flashinfer.ai/tutorials/kv_layout.html#mask-layout-2d-ragged-tensor)
* Support logits cap in Grok-1 models.
* Fused GPU-sampling kernels: top-p, top-k, speculative verification. (https://docs.flashinfer.ai/api/python/sampling.html)
* PyTorch wrapper of group-gemm cutlass kernels. (https://docs.flashinfer.ai/api/python/sampling.html)

### Acknowledgement

We thank [@ibsidorenko](https://github.com/ibsidorenko), [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU), [@Yard1](https://github.com/Yard1), [@AgrawalAmey](https://github.com/AgrawalAmey), [@xuzhenqi](https://github.com/xuzhenqi), [@mgerstgrasser](https://github.com/mgerstgrasser), [@esmeetu](https://github.com/esmeetu), [@yz-tang](https://github.com/yz-tang), [@HSQ79815](https://github.com/HSQ79815), [@Qubitium](https://github.com/Qubitium), [@shreygupta2809](https://github.com/shreygupta2809), [@sighingnow](https://github.com/sighingnow), [@vinx13](https://github.com/vinx13), [@tqchen](https://github.com/tqchen), [@merrymercy](https://github.com/merrymercy), [@comaniac](https://github.com/comaniac) and many others for their contributions and helpful discussions for the 0.0.5 release.

### Refactor

* support any GQA group size for tensor-cores kernels ([#301](#301)) ([c111ca](c111ca6))
* support any page size for tensor-cores kernels ([#306](#306)) ([82fd8c](82fd8c7))

### Features

* add `use_tensor_cores` option to decode kernels to accelerate GQA ([#317](#317)) ([3b50dd5](3b50dd5))
* add group gemm operators ([#282](#282)) ([e08ba42](e08ba42))
* initial support of distributed operators ([#289](#289)) ([03553da](03553da))
* initial support of logits hook ([#298](#298)) ([ab1e2ad](ab1e2ad))
* Separate Q and KV dtypes for decode ([#286](#286)) ([5602659](5602659))
* support cuda graph for batched multi-query (prefill/append) attention ([#275](#275)) ([83ceb67](83ceb67))
* support cuda graph for batched multi-query (prefill/append) attention ([#277](#277)) ([24cc583](24cc583))
* support custom attention mask in prefill/append attention kernels ([#266](#266)) ([7304282](7304282))
* fused speculative sampling kernels ([#259](#259)) ([cea2bb](cea2bb9))
* expose sampling APIs in pytorch ([#238](#238)) ([092902](0929023))

### Performance Improvements

* initial cuda graph support ([#256](#256)) ([7e9cc7f](7e9cc7f))
* split kv-cache for prefill/append kernels ([#310](#310)) ([f0bb0a3](f0bb0a3))
* use packed bit array for attention mask ([#308](#308)) ([3d43dc9](3d43dc9))

---

This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
Closes #285
Modified unit tests pass. May need some extra validation.
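For context, a rough end-to-end sketch of what this feature enables (fp16 queries over an fp8 KV cache). The wrapper and `begin_forward` arguments follow the flashinfer Python API quoted above, but the `q_data_type` keyword and the exact tensor shapes are assumptions for illustration, not a verbatim excerpt from this PR:

```python
# Illustrative only: decode with fp16 queries over an fp8 (e4m3) paged KV cache.
# The q_data_type keyword is an assumed name; check the merged API for the
# exact argument.
import torch
import flashinfer

batch_size, pages_per_seq, page_size = 4, 4, 16
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(batch_size * pages_per_seq, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.begin_forward(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    data_type=torch.float8_e4m3fn,  # KV-cache dtype
    q_data_type=torch.float16,      # query dtype (assumed keyword)
)

# Paged KV cache in NHD layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
kv_data = torch.randn(
    batch_size * pages_per_seq, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
).to(torch.float8_e4m3fn)
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")

out = wrapper.forward(q, kv_data)  # attention output in fp16
wrapper.end_forward()
```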