feat: support bmm fp8 #469
Conversation
Another suggestion is to move group gemm and bmm fp8 into a common gemm.py; we should also update group_gemm.rst (renaming it to gemm.rst) as well.
Makes sense.
LGTM, thanks for your contribution @zhyncs!
The documentation was not indexed properly in #469; this PR fixes the issue.
🤖 I have created a release *beep* *boop*

---

## [0.1.6](v0.1.5...v0.1.6) (2024-08-27)

### SM75 Support

Starting from [0.1.6](v0.1.5...v0.1.6), our pre-built wheels include experimental support for sm75 (Turing architecture GPUs such as Tesla T4, Quadro RTX 6000 and RTX 2080).

### API Changes

#### `plan`/`run`

Since [0.1.6](v0.1.5...v0.1.6), the `begin_forward`/`forward`/`end_forward` APIs are replaced with the new `plan`/`run` APIs:

- `forward` is renamed to `run`, which is more precise and consistent with the naming convention of CUTLASS's Python API.
- `begin_forward` is renamed to `plan`, which is consistent with the naming convention of the nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There are some slight differences between the old `forward` and the new `run` API:

- All extra arguments such as `causal` and `logits_soft_cap` are provided in the `plan` (previously `begin_forward`) API and cached until the next `plan` call; only the query and KV-Cache tensors need to be provided in the `run` API (a usage sketch follows these release notes).

The old `begin_forward`/`forward`/`end_forward` APIs are still functional, but will be gradually deprecated in future releases. Check [#466](#466) for more details.

#### `MultiLevelCascadeAttentionWrapper`

Since [0.1.6](v0.1.5...v0.1.6), we introduce a new `MultiLevelCascadeAttentionWrapper` API for cascade inference, which supports multi-level cascade inference where all levels' KV-Cache can be managed in a unified paged KV-Cache. See the [documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper) and [tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout) for API usage and layout explanation.

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and `BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in future releases.

### Features

* sm75 support ([#448](#448), [#449](#449))
* add `MultiLevelCascadeAttentionWrapper` API ([#462](#462)) ([1e37989](1e37989))
* add accept num, emit num metric for ChainSpeculativeSampling ([#450](#450)) ([fa38b5e](fa38b5e))
* support bmm fp8 ([#469](#469)) ([f1c0b68](f1c0b68))

### Refactor

* replace `begin_forward`/`forward`/`end_forward` with `plan`/`run` ([#466](#466))

### Misc

* improve error handling of sampling kernels ([#456](#456)) ([0dce178](0dce178))

### Performance Improvements

* slight optimization on f16->f8 fragment layout swizzling ([#453](#453)) ([0d61871](0d61871))
* slight optimization on fragment layout swizzle ([#458](#458)) ([7c397cb](7c397cb))
* use persistent kernel for merging attention states ([#459](#459)) ([be6bf5b](be6bf5b))

### Acknowledgement

We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) for enhancing the speculative sampling operator, [@merrymercy](https://github.com/merrymercy) for the API change suggestions, and [@zhyncs](https://github.com/zhyncs) for integrating the fp8 BMM cuBLAS implementation.

---

This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
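For readers migrating existing code, here is a rough sketch of what the `begin_forward`/`forward`/`end_forward` → `plan`/`run` change looks like for a paged-KV prefill wrapper. The wrapper name and overall flow follow the FlashInfer docs, but the exact argument names and ordering of `plan` at 0.1.6 are assumptions here, not copied from the release.

```python
import torch
import flashinfer

# Sketch only: plan()'s argument names/ordering are assumed, not verbatim 0.1.6.
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_seq = 4, 8
total_pages = batch_size * pages_per_seq

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# Page-table metadata for 4 requests, each with 8 full pages of KV and 16 query tokens.
qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * page_size
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(total_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

q = torch.randn(batch_size * page_size, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda")
kv_cache = torch.randn(total_pages, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")

# Old style (still functional, slated for deprecation):
#   wrapper.begin_forward(qo_indptr, kv_indptr, kv_indices, kv_last_page_len, ...)
#   out = wrapper.forward(q, kv_cache, causal=True, logits_soft_cap=30.0)
#   wrapper.end_forward()

# New style: extra arguments (causal, logits_soft_cap, ...) move into plan() and
# are cached until the next plan() call; run() only takes query and KV-Cache.
wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    causal=True, logits_soft_cap=30.0,
)
out = wrapper.run(q, kv_cache)
```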
`torch.bmm` doesn't support fp8 and `torch._scaled_mm` doesn't support 3D inputs, so I wrote this one. @yzh119 cc @merrymercy @Ying1123 @ispobock. Thanks @yzh119 for assisting with debugging.
- AType: fp8 e4m3, fp8 e5m2
- BType: fp8 e4m3, fp8 e5m2
- DType: bf16, fp16

Using fp8 e5m2 for both AType and BType is not supported (ref: https://docs.nvidia.com/cuda/cublas/#cublasltmatmul).
Works on H100.
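For context, a minimal usage sketch on an sm90 GPU such as H100, assuming the new operator is exposed as `flashinfer.bmm_fp8(A, B, A_scale, B_scale, dtype)` with per-tensor inverse scales; the `to_float8` quantization helper and all shapes below are illustrative, not part of the PR.

```python
import torch
import flashinfer


def to_float8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Hypothetical per-tensor quantization helper (not part of the PR):
    # scale into the fp8 range and return the fp8 tensor plus its inverse scale.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12)
    scale = finfo.max / amax
    x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float().reciprocal()


# A: (b, m, k) fp8 e4m3; B: (b, k, n) fp8 e4m3, made column-major via transpose
# (cublasLtMatmul's fp8 path expects a column-major B operand); D comes back in bf16.
A_bf16 = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
B_bf16 = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
A_fp8, A_inv_scale = to_float8(A_bf16)
B_fp8, B_inv_scale = to_float8(B_bf16)

# Assumed call shape: bmm_fp8(A, B, A_scale, B_scale, out_dtype)
D = flashinfer.bmm_fp8(A_fp8, B_fp8, A_inv_scale, B_inv_scale, torch.bfloat16)
print(D.shape)  # torch.Size([16, 48, 80])
```

Keeping both operands in e4m3 (or mixing e4m3 with e5m2) stays within the cublasLtMatmul constraint noted above.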