
feat: append attention kernels for fp8 kv-cache #420

Merged: yzh119 merged 19 commits from fp8-with-fp16-tc into main on Aug 6, 2024

Conversation

yzh119 (Collaborator) commented Aug 5, 2024

This implementation does not rely on fp8 tensor cores; it uses fp16 tensor cores instead (so sm_80 architectures can also use it), and the fp8 kv-cache is dequantized on the fly.

sm_89 and sm_90 append attention kernels that use native fp8 tensor cores will be added in later PRs.
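
For intuition, the computation is equivalent to the following PyTorch sketch (a conceptual reference only, not the actual CUDA implementation; paging and per-tensor scales are omitted):

```python
import torch

# fp16 queries against an fp8 kv-cache: the cache is stored in fp8,
# but the matmuls run on fp16 tensor cores after an on-the-fly upcast.
q = torch.randn(128, 64, dtype=torch.float16, device="cuda")
k_fp8 = torch.randn(256, 64, device="cuda").to(torch.float8_e4m3fn)
v_fp8 = torch.randn(256, 64, device="cuda").to(torch.float8_e4m3fn)

k = k_fp8.to(torch.float16)  # dequantized on the fly inside the kernel
v = v_fp8.to(torch.float16)
scores = torch.softmax((q @ k.T) / 64**0.5, dim=-1)
out = scores @ v  # output stays in fp16
```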

CMakeLists.txt Outdated
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
if(FLASHINFER_ENABLE_FP8)
list(APPEND DECODE_DTYPES "e4m3" "e5m2")
list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
list(APPEND PREFILL_FP8_DTYPES "e4m3")
Collaborator:

Do we also support e5m2?

Contributor:

e5m2 support would be great

yzh119 (Collaborator, Author):

e5m2 support has been added. Note that this CMake option only affects the C++ tests; it has nothing to do with the Python wheels.

yzh119 (Collaborator, Author) commented Aug 6, 2024

@Yard1 @comaniac @cassiewilliam
To keep binary size reasonable (<2GB), I only kept fp8 support for BatchPrefillWithPagedKVCacheWrapper, which should be general enough for append attention.

The functionality tests have passed. Feel free to try it and report any possible issues.
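
A usage sketch (argument names follow the v0.1.x Python API as I understand it; exact signatures may differ between versions, and the single-request shapes below are illustrative):

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# One request appending 7 new tokens against 3 pages of kv-cache,
# with 5 valid entries on the last page.
qo_indptr = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 3], dtype=torch.int32, device="cuda")
kv_indices = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([5], dtype=torch.int32, device="cuda")

wrapper.begin_forward(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
                      num_qo_heads, num_kv_heads, head_dim, page_size)

q = torch.randn(7, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
# NHD paged layout: (num_pages, 2, page_size, num_kv_heads, head_dim), stored as fp8
kv_data = torch.randn(3, 2, page_size, num_kv_heads, head_dim,
                      device="cuda").to(torch.float8_e4m3fn)
o = wrapper.forward(q, kv_data, causal=True)  # output comes back in fp16
wrapper.end_forward()
```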

yzh119 merged commit 906c2f5 into main on Aug 6, 2024
yzh119 added a commit that referenced this pull request Aug 6, 2024
The swizzling mode names in #420 are wrong; this PR aligns them with the PTX documentation: 32B -> 64B, 64B -> 128B.
yzh119 added a commit that referenced this pull request Aug 9, 2024
🤖 I have created a release *beep* *boop*
---
## [0.1.4](v0.1.3...v0.1.4) (2024-08-09)


### Features

* append attention kernels for fp8 kv-cache ([#420](#420)) ([906c2f5](906c2f5))
* support min_p sampling ([#422](#422)) ([d52f2da](d52f2da))
* deterministic sampling ([#417](#417)) ([0dd801d](0dd801d))
* more sampling operator options ([#431](#431)) ([68df9c4](68df9c4))
* support fused add rmsnorm ([#419](#419)) ([b781513](b781513))
* support fused silu mul ([#427](#427)) ([ea0ba9a](ea0ba9a))

### Bug Fixes

* fix dispatch fp16 type when enable fp8 ([#430](#430)) ([daa5566](daa5566))
* improve numerical stability of sampling kernels ([#429](#429)) ([898d8ea](898d8ea))

### Other improvements

* break up `_kernels` into multiple modules ([#428](#428)) ([8e482d9](8e482d9))

### Acknowledgement

We thank the community for their contributions and feedback:
[@comaniac](https://github.com/comaniac),
[@esmeetu](https://github.com/esmeetu),
[@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU),
[@peng1999](https://github.com/peng1999),
[@xslingcn](https://github.com/xslingcn),
[@Yard1](https://github.com/Yard1),
[@zhyncs](https://github.com/zhyncs).

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
yzh119 deleted the fp8-with-fp16-tc branch on August 10, 2024 at 18:38
yzh119 added a commit that referenced this pull request Aug 11, 2024
The hardware fp8->fp16 fast conversion instruction is not available on sm_80 & sm_89, which makes #420 slow on these architectures.

This PR uses Marlin's fast fp8->fp16x4 conversion algorithm (copied from the vLLM project) to accelerate such cases.

Co-authored-by: Antoni Baum <antoni@anyscale.com>
Co-authored-by: Cody Yu <cody@anyscale.com>
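
For intuition, here is a numpy illustration of the Marlin-style trick for e4m3 (a sketch of the idea only; the actual kernel does this with a handful of integer instructions per four values, and e4m3 NaN encodings are not preserved):

```python
import numpy as np

def dequant_e4m3x4(q: int) -> np.ndarray:
    """Dequantize four e4m3 values packed in one 32-bit word to fp16.

    The fp8 bits are shifted into fp16 bit positions with masks, then a
    single multiply by 2**(15 - 7) = 256 fixes the exponent-bias gap
    (this also handles e4m3 denormals correctly).
    """
    MASK, SIGN = 0x7F007F00, 0x80008000   # exp+mantissa / sign bits of the
                                          # fp8 value in each lane's high byte
    out1 = (q & SIGN) | ((q & MASK) >> 1)        # lanes hold bytes 1 and 3
    qs = (q << 8) & 0xFFFFFFFF
    out2 = (qs & SIGN) | ((qs & MASK) >> 1)      # lanes hold bytes 0 and 2
    lanes = [out2 & 0xFFFF, out1 & 0xFFFF, out2 >> 16, out1 >> 16]  # b0..b3
    return np.array(lanes, dtype=np.uint16).view(np.float16) * np.float16(256)

# 0x40 encodes 2.0 in e4m3; four copies packed in one word:
print(dequant_e4m3x4(0x40404040))  # -> [2. 2. 2. 2.]
```

The e5m2 case is simpler still: e5m2 is a truncated fp16 bit pattern, so widening each byte with zero-filled low bits (a single byte-permute instruction in CUDA) yields the exact fp16 value.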
zhyncs pushed a commit that referenced this pull request Aug 14, 2024