Fix PagedPrefill python api and some typos #441

Merged: 3 commits, Aug 13, 2024

Conversation

jianfei-wangg
Contributor

Fix two small bugs:

  1. "NHD" and "HND" were used in a confusing way (see the layout sketch after the repro below).
  2. PagedPrefill uses self._custom_mask_buf to decide whether a customized mask is in use, but that buffer is never initialized.
    Here is the code snippet to reproduce the 2nd bug:
import torch
import flashinfer

# try to reproduce the bug in a speculative-decoding case
device = torch.device("cuda:0")
num_heads = 32
num_qo_heads = num_heads
num_kv_heads = 32
head_dim = 128
page_size = 4
max_num_pages = 4
batch_size = 1
seq_len = 4
query = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16, device=device)
packed_kv_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device=device)
ragged_key_cache = packed_kv_cache[:, 0].reshape(-1, num_kv_heads, head_dim)
ragged_value_cache = packed_kv_cache[:, 1].reshape(-1, num_kv_heads, head_dim)

# custom attention mask of shape [4, 15] (qo_len = 4, kv_len = 15)
attn_mask = torch.tensor([
    [ True,  True,  True,  True,  True,  True,  True,  True, False, False, False,  True, False, False, False],
    [ True,  True,  True,  True,  True,  True,  True, False,  True, False, False, False,  True, False, False],
    [ True,  True,  True,  True,  True,  True,  True,  True, False, False, False, False, False,  True, False],
    [ True,  True,  True,  True,  True,  True,  True, False, False,  True, False, False, False, False,  True]
    ], device=device)

mask = attn_mask.reshape(-1)
# packed_mask = flashinfer.quantization.packbits(mask)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
paged_prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace_buffer, "NHD"
)
kv_page_indices = torch.arange(max_num_pages).int().to("cuda:0")
kv_page_indptr = torch.tensor(
    [0, 4], dtype=torch.int32, device="cuda:0"
)
# 1 <= kv_last_page_len <= page_size
kv_last_page_len = torch.tensor(
    [3], dtype=torch.int32, device="cuda:0"
)
qo_indptr = torch.tensor(
    [0, 4], dtype=torch.int32, device="cuda:0"
)

# create auxiliary data structures for batch prefill attention
paged_prefill_wrapper.begin_forward(
    qo_indptr,
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    mask,
    q_data_type=torch.bfloat16
)
# assert torch.equal(paged_prefill_wrapper._custom_mask, packed_mask)
# assert paged_prefill_wrapper._custom_mask_buf is not None
q = query
o = paged_prefill_wrapper.forward(q, packed_kv_cache, causal=False)
paged_prefill_wrapper.end_forward()

# ragged attn
workspace_buffer_ragged = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
ragged_prefill_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer_ragged, "NHD"
)
kv_indptr = torch.tensor(
    [0, 15], dtype=torch.int32, device="cuda:0"
)
ragged_prefill_wrapper.begin_forward(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    mask,
    q_data_type="bfloat16",
)
ragged_o = ragged_prefill_wrapper.forward(q, ragged_key_cache, ragged_value_cache)
ragged_prefill_wrapper.end_forward()
print("query shape: ", q.shape)
print("paged vs ragged allclose: ", torch.allclose(o, ragged_o, rtol=1e-3, atol=1e-3))
print("paged vs ragged equal: ", torch.equal(o, ragged_o))
assert torch.allclose(o, ragged_o, rtol=1e-3, atol=1e-3)
assert torch.equal(o, ragged_o)
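
For the first bug, here is a minimal sketch of the two KV-cache layouts as I read them from the flashinfer docs (the exact 5-D shape convention below mirrors the packed_kv_cache above for "NHD"; the "HND" shape is my assumption, not something stated in this PR): "NHD" keeps the page (sequence) dimension before the head dimension, while "HND" puts the head dimension first.

import torch

# sizes matching the repro above
max_num_pages, page_size, num_kv_heads, head_dim = 4, 4, 32, 128

# "NHD": per page, the layout is (page_size, num_kv_heads, head_dim)
nhd_kv_cache = torch.empty(max_num_pages, 2, page_size, num_kv_heads, head_dim)

# "HND": per page, the layout is (num_kv_heads, page_size, head_dim)
hnd_kv_cache = torch.empty(max_num_pages, 2, num_kv_heads, page_size, head_dim)

# converting between the two is just a transpose of the page and head dims
assert nhd_kv_cache.transpose(2, 3).shape == hnd_kv_cache.shape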
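
The commented-out packbits lines above hint at what the second bug is about: begin_forward is expected to bit-pack the flattened boolean mask and stash it on the wrapper so that forward knows a custom mask is in play. As a rough, self-contained illustration of what that packing amounts to (not flashinfer's actual implementation; the helper name and the little-endian bit order are assumptions), 8 mask entries fold into one uint8:

import torch

def pack_bool_mask(mask: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: flatten, pad to a multiple of 8, fold every 8 bools into one byte
    flat = mask.reshape(-1).to(torch.int64)
    pad = (-flat.numel()) % 8
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=flat.device)  # little-endian assumed
    return (flat.reshape(-1, 8) * weights).sum(dim=1).to(torch.uint8)

# the 4 x 15 mask above has 60 entries, so it packs into 8 bytes
packed = pack_bool_mask(torch.rand(4, 15) > 0.5)
print(packed.shape)  # torch.Size([8])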

@yzh119 (Collaborator) left a comment

LGTM @jianfei-wangg, thanks for your contribution!

yzh119 merged commit 3fff008 into flashinfer-ai:main on Aug 13, 2024
yzh119 added a commit that referenced this pull request on Aug 13, 2024
🤖 I have created a release *beep* *boop*
---


## [0.1.5](v0.1.4...v0.1.5) (2024-08-13)


### Bugfix

* Fix PagedPrefill python api and some typos
([#441](#441))
([3fff008](3fff008))
* fix prefill kernels' lse result for empty kv-cache
([#440](#440))
([6ac28f4](6ac28f4))

### Features

* decouple float and int workspace buffer
([#442](#442))
([a7ee566](a7ee566))


### Performance Improvements

* faster fp8->fp16 dequantization for pre-sm_90 arch
([#439](#439))
([c93f647](c93f647))

### Acknowledgement

We thank the community for their contributions and feedback:
[@comaniac](https://github.com/comaniac),
[@hnyls2002](https://github.com/hnyls2002),
[@jianfei-wangg](https://github.com/jianfei-wangg),
[@Yard1](https://github.com/Yard1).


---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>