feat: support non-contiguous (packed) input for prefill kernels #404

Merged
merged 6 commits into from
Jul 29, 2024


yzh119
Collaborator

@yzh119 yzh119 commented Jul 29, 2024

This PR implements #311. After this PR, we support packed qkv input without explicitly making the input contiguous:

packed_qkv = W_qkv(x) # (nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim)
q = packed_qkv[..., : num_qo_heads * head_dim].reshape(-1, num_qo_heads, head_dim)
k = packed_qkv[..., num_qo_heads * head_dim: (num_qo_heads + num_kv_heads) * head_dim].reshape(-1, num_kv_heads, head_dim)
v = packed_qkv[..., (num_qo_heads + num_kv_heads) * head_dim:].reshape(-1, num_kv_heads, head_dim)
apply_rope_inplace(q, k, indptr, offsets)
ragged_prefill_wrapper.forward(q, k, v)

Before this PR, we needed to make q/k/v contiguous before launching the attention kernel, which incurred some overhead.
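To make the layout concrete, here is a minimal pure-Python sketch of the offsets and strides that the three slices above end up with. It assumes `packed_qkv` is a row-major 2-D buffer of shape `(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim)`; the names `ViewLayout`, `packed_qkv_layouts` and `is_contiguous` are hypothetical helpers for illustration and are not part of flashinfer. The contiguity check is a simplified version that ignores size-1 dimensions.

```python
from typing import NamedTuple, Tuple


class ViewLayout(NamedTuple):
    """Storage offset, shape and element strides of a (nnz, heads, head_dim) view."""
    offset: int
    shape: Tuple[int, int, int]
    strides: Tuple[int, int, int]


def packed_qkv_layouts(nnz, num_qo_heads, num_kv_heads, head_dim):
    """Layouts of the q/k/v views sliced out of one packed row-major buffer.

    Hypothetical helper: it only mirrors the slicing arithmetic shown above.
    """
    # Every token occupies one full packed row, so the token stride of each
    # view is the packed row width, not that view's own heads * head_dim.
    row_width = (num_qo_heads + 2 * num_kv_heads) * head_dim

    def view(col_start, num_heads):
        return ViewLayout(col_start, (nnz, num_heads, head_dim),
                          (row_width, head_dim, 1))

    q = view(0, num_qo_heads)
    k = view(num_qo_heads * head_dim, num_kv_heads)
    v = view((num_qo_heads + num_kv_heads) * head_dim, num_kv_heads)
    return q, k, v


def is_contiguous(layout):
    """Simplified check: strides must match a dense row-major layout."""
    _, num_heads, head_dim = layout.shape
    return layout.strides == (num_heads * head_dim, head_dim, 1)


q, k, v = packed_qkv_layouts(nnz=4, num_qo_heads=32, num_kv_heads=8, head_dim=128)
print(q.strides, is_contiguous(q))
```

Since the token stride is the packed row width rather than `num_heads * head_dim`, none of the three views is dense, which is why a copy was previously required.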

I observe a slight (<1%) performance degradation after this PR for non-packed input, which is acceptable IMO.

@yzh119 yzh119 merged commit 68c3719 into main Jul 29, 2024
yzh119 added a commit that referenced this pull request Jul 29, 2024
🤖 I have created a release *beep* *boop*
---

## [0.1.2](v0.1.1...v0.1.2) (2024-07-29)

### Bugfix
* Fix the sampling kernel bug for cu118 ([#386](#386), [#387](#387)) ([0cd499](0cd4994), [dc3f18](dc3f184))

### Features

* add llama 3.1 style rope ([#401](#401)) ([4c89dec](4c89dec))
* non-inplace rope operators ([#405](#405)) ([74ffba1](74ffba1))
* sliding window attention ([#406](#406)) ([28cffd3](28cffd3))
* support non-contiguous (packed) input for prefill kernels ([#404](#404)) ([68c3719](68c3719))


### Performance Improvements

* slight optimization on merge states ([#313](#313)) ([701c813](701c813))

---
This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
@yzh119 yzh119 deleted the allow-packed branch August 3, 2024 00:20
yzh119 pushed a commit that referenced this pull request Oct 25, 2024
## Motivation

Previously, only the ragged version of the prefill kernel supported
non-contiguous query tensors (#404). With a paged kv cache, the query
tensor still had to be contiguous, so libraries like vLLM or SGLang must
make it contiguous before calling flashinfer kernels ([vLLM call of
flashinfer](https://github.com/vllm-project/vllm/blob/b7df53cd42f3eab007b4f287c151960858e949df/vllm/attention/backends/flashinfer.py#L839),
[SGLang call of
flashinfer](https://github.com/sgl-project/sglang/blob/87a7cfa080cec3f123618c1429b5f998bf5d99cb/python/sglang/srt/layers/attention/flashinfer_backend.py#L236)).
This PR removes that requirement: prefill/decode kernels with a paged kv
cache now support non-contiguous query tensors.

## Main Changes

1. Add the strides of the query tensor to `BatchPrefillPagedParams` and
`BatchDecodeParams`.
2. Set the stride parameters before calling those kernels.
3. Modify the JIT compilation templates to support the new kernel
parameters.
4. Add some tests.

The Python interfaces remain the same. The only change is that they now
accept non-contiguous query tensors!
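As a rough illustration of what passing query strides to a kernel buys, the sketch below indexes a query element through explicit strides instead of assuming a dense layout. It is plain Python rather than the actual CUDA code, and `QueryParams`, `q_offset`, `q_stride_n`, `q_stride_h`, `load_q` and `contiguous_copy` are hypothetical names chosen for this example, not the real fields of `BatchPrefillPagedParams` or `BatchDecodeParams`.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class QueryParams:
    """Hypothetical stand-in for a kernel parameter struct carrying query strides."""
    q: List[float]      # flat storage that the query view points into
    q_offset: int       # element offset of q[0, 0, 0] inside the storage
    q_stride_n: int     # elements between consecutive tokens
    q_stride_h: int     # elements between consecutive heads


def load_q(params: QueryParams, token: int, head: int, dim: int) -> float:
    """Read q[token, head, dim] using strides; head_dim is assumed to have stride 1."""
    index = (params.q_offset + token * params.q_stride_n
             + head * params.q_stride_h + dim)
    return params.q[index]


def contiguous_copy(params: QueryParams, nnz: int, num_heads: int,
                    head_dim: int) -> QueryParams:
    """Dense copy of the view: the extra step that stride-aware indexing avoids."""
    dense = [load_q(params, n, h, d)
             for n in range(nnz) for h in range(num_heads) for d in range(head_dim)]
    return QueryParams(dense, 0, num_heads * head_dim, head_dim)


# Packed buffer with nnz=2, num_qo_heads=2, num_kv_heads=1, head_dim=2,
# so each packed row holds (2 + 2 * 1) * 2 = 8 elements.
packed = [float(i) for i in range(16)]
q_view = QueryParams(packed, q_offset=0, q_stride_n=8, q_stride_h=2)
q_dense = contiguous_copy(q_view, nnz=2, num_heads=2, head_dim=2)
```

Both `q_view` and `q_dense` return the same values through `load_q`; only the stride fields differ, which is consistent with the note above that the Python interface did not need to change.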

---------

Signed-off-by: LinHeLurking <LinHe.Lurking@gmail.com>