feat: append attention kernels for fp8 kv-cache #420
Conversation
CMakeLists.txt (outdated diff)
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
 if(FLASHINFER_ENABLE_FP8)
   list(APPEND DECODE_DTYPES "e4m3" "e5m2")
   list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
+  list(APPEND PREFILL_FP8_DTYPES "e4m3")
Do we also support e5m2?
e5m2 support would be great
e5m2 added. Note, however, that this build-time list only affects the C++ tests and has nothing to do with the Python wheels.
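For context on the two formats in that list: e4m3 spends more bits on mantissa (better precision, max finite value 448), while e5m2 spends more bits on exponent (wider range, coarser precision). A minimal, flashinfer-independent illustration using CUDA's `<cuda_fp8.h>` (CUDA 11.8+); nothing below is flashinfer API:

```cuda
// Round-trip two floats through both FP8 formats to see what each preserves.
// e4m3: 1 sign, 4 exponent, 3 mantissa bits; e5m2: 1 sign, 5 exponent, 2 mantissa bits.
#include <cuda_fp8.h>
#include <cstdio>

int main() {
  float small = 0.17f;    // exercises mantissa precision
  float large = 1000.0f;  // exceeds e4m3's finite range (~448), fits in e5m2's
  printf("e4m3: %g -> %g, %g -> %g\n",
         small, float(__nv_fp8_e4m3(small)), large, float(__nv_fp8_e4m3(large)));
  printf("e5m2: %g -> %g, %g -> %g\n",
         small, float(__nv_fp8_e5m2(small)), large, float(__nv_fp8_e5m2(large)));
  return 0;
}
```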
@Yard1 @comaniac @cassiewilliam The functionality tests have passed. Feel free to try it and report any possible issues.
The swizzling mode names introduced in #420 are wrong; this PR aligns them with the PTX documentation: 32B -> 64B, 64B -> 128B.
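For readers unfamiliar with the naming being corrected: the PTX ISA labels shared-memory swizzling patterns by their byte span (none / 32B / 64B / 128B), and the patterns #420 called "32B" and "64B" correspond to what PTX calls 64B and 128B. A purely illustrative sketch of the rename; the enum and identifiers below are hypothetical, not flashinfer's actual code:

```cuda
// Hypothetical naming sketch only (illustrative identifiers, not flashinfer's).
enum class SwizzleMode {
  kNone,   // no swizzling
  k64B,    // previously (mis)named "32B"; matches the PTX 64B swizzle pattern
  k128B,   // previously (mis)named "64B"; matches the PTX 128B swizzle pattern
};
```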
🤖 I have created a release *beep* *boop*

---

## [0.1.4](v0.1.3...v0.1.4) (2024-08-09)

### Features

* append attention kernels for fp8 kv-cache ([#420](#420)) ([906c2f5](906c2f5))
* support min_p sampling ([#422](#422)) ([d52f2da](d52f2da))
* deterministic sampling ([#417](#417)) ([0dd801d](0dd801d))
* more sampling operator options ([#431](#431)) ([68df9c4](68df9c4))
* support fused add rmsnorm ([#419](#419)) ([b781513](b781513))
* support fused silu mul ([#427](#427)) ([ea0ba9a](ea0ba9a))

### Bug Fixes

* fix dispatch fp16 type when enable fp8 ([#430](#430)) ([daa5566](daa5566))
* improve numerical stability of sampling kernels ([#429](#429)) ([898d8ea](898d8ea))

### Other improvements

* break up `_kernels` into multiple modules ([#428](#428)) ([8e482d9](8e482d9))

### Acknowledgement

We thank contributions and feedbacks from the community: [@comaniac](https://github.com/comaniac), [@esmeetu](https://github.com/esmeetu), [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU), [@peng1999](https://github.com/peng1999), [@xslingcn](https://github.com/xslingcn), [@Yard1](https://github.com/Yard1), [@zhyncs](https://github.com/zhyncs).

---

This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
The hardware fp8->fp16 fast-conversion instruction is not available on sm_80 and sm_89, which makes #420 slow on these architectures. This PR uses Marlin's fast fp8->fp16x4 conversion algorithm (copied from the vLLM project) to accelerate such cases.
Co-authored-by: Antoni Baum <antoni@anyscale.com>
Co-authored-by: Cody Yu <cody@anyscale.com>
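For readers curious how that conversion works, below is a self-contained sketch of the Marlin-style trick; it illustrates the idea, not the PR's exact code, and the function names are made up. An e4m3 byte placed in the upper byte of a 16-bit lane already has its sign bit where fp16 expects it; shifting the exponent and mantissa right by one bit aligns them with fp16's 5-bit exponent field, and a final multiply by 2^(15-7) = 256 corrects the exponent-bias difference, with no fp8 hardware conversion instruction required (so it also runs on sm_80):

```cuda
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

// Convert four packed e4m3 values (one per byte of x) to four fp16 values
// using only integer bit tricks and one half2 multiply per pair.
__device__ inline void fast_e4m3x4_to_half2x2(uint32_t x, half2& even, half2& odd) {
  uint32_t hi_bytes = x & 0xFF00FF00u;         // elements 1 and 3, already in bits [15:8]
  uint32_t lo_bytes = (x << 8) & 0xFF00FF00u;  // elements 0 and 2, moved into bits [15:8]
  // Keep the sign (bit 15) and shift exponent+mantissa right by 1 inside each lane.
  uint32_t h_odd  = (hi_bytes & 0x80008000u) | ((hi_bytes & 0x7F007F00u) >> 1);
  uint32_t h_even = (lo_bytes & 0x80008000u) | ((lo_bytes & 0x7F007F00u) >> 1);
  // Multiply by 2^(15-7) = 256 to fix the bias difference (fp16 bias 15, e4m3 bias 7).
  // Note: e4m3 NaN encodings are not special-cased in this sketch.
  const half2 bias = __float2half2_rn(256.0f);
  even = __hmul2(*reinterpret_cast<half2*>(&h_even), bias);
  odd  = __hmul2(*reinterpret_cast<half2*>(&h_odd), bias);
}

__global__ void demo(uint32_t packed, float* out) {
  half2 even, odd;
  fast_e4m3x4_to_half2x2(packed, even, odd);
  out[0] = __low2float(even);   // element 0
  out[1] = __low2float(odd);    // element 1
  out[2] = __high2float(even);  // element 2
  out[3] = __high2float(odd);   // element 3
}

int main() {
  // e4m3 encodings: 0x38 = 1.0, 0x40 = 2.0, 0xC0 = -2.0, 0x44 = 3.0
  uint32_t packed = 0x44C04038u;  // elements 0..3 = 1.0, 2.0, -2.0, 3.0
  float* out;
  cudaMallocManaged(&out, 4 * sizeof(float));
  demo<<<1, 1>>>(packed, out);
  cudaDeviceSynchronize();
  printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // expect: 1 2 -2 3
  cudaFree(out);
  return 0;
}
```

The two `half2` outputs come out interleaved (elements 0/2 and 1/3) simply because of how the bytes are extracted; Marlin keeps a similar interleaving because it matches the tensor-core fragment layout, and callers can reorder lanes if they need natural order.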
This implementation does not rely on fp8 tensor cores; it uses fp16 tensor cores instead (so sm_80 architectures can also use it), and the fp8 kv-cache is dequantized on the fly.
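To make "dequantized on the fly" concrete, here is a minimal sketch of the idea; the per-head scale, function names, and scalar loop are illustrative assumptions rather than flashinfer's API, and the real kernels work on vectorized fragments fed straight into fp16 MMA instructions:

```cuda
#include <cuda_fp8.h>
#include <cuda_fp16.h>

// Widen one fp8 KV element to fp16 in registers and undo its quantization scale,
// right before the fp16 math consumes it (no fp16 copy of the cache in memory).
__device__ inline half dequant_kv(__nv_fp8_e4m3 v, half scale) {
  return __hmul(__half(v), scale);
}

// Toy single-thread q·k dot product over one head dimension; real kernels
// dequantize whole tiles and hand them to fp16 tensor-core MMA instead.
__global__ void qk_dot_fp8kv(const half* q, const __nv_fp8_e4m3* k,
                             half k_scale, int head_dim, float* out) {
  float acc = 0.f;
  for (int i = 0; i < head_dim; ++i) {
    acc += __half2float(q[i]) * __half2float(dequant_kv(k[i], k_scale));
  }
  *out = acc;
}
```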
Append attention kernels that use native fp8 tensor cores on sm_89 and sm_90 will be available in later PRs.