
feat: support bmm fp8 #469

Merged: 16 commits merged into main from fp8-bmm-scale on Aug 26, 2024
Conversation

@zhyncs (Member) commented Aug 26, 2024

torch.bmm doesn't support fp8 and torch._scaled_mm doesn't support 3D inputs, so I wrote this one. @yzh119 cc @merrymercy @Ying1123 @ispobock

Thanks @yzh119 for assisting with debugging.

AType: fp8 e4m3, fp8 e5m2
BType: fp8 e4m3, fp8 e5m2
DType: bf16, fp16

AType and BType cannot both be fp8 e5m2; see https://docs.nvidia.com/cuda/cublas/#cublasltmatmul
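
For reference, a minimal usage sketch. Assumptions: the new op is exposed as `flashinfer.bmm_fp8(A, B, A_scale, B_scale, dtype)` with per-tensor inverse scales, and the second operand is passed column-major (via a transpose) to satisfy the cublasLtMatmul layout requirement; the quantization helper below is illustrative, not the exact test code.

```python
import torch
import flashinfer

def to_float8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Per-tensor quantization: map the tensor's max magnitude onto the fp8 range,
    # return the fp8 tensor plus the inverse scale used to dequantize the output.
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax
    x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float().reciprocal()

A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
# Transpose so the second operand is column-major (assumed cublasLt requirement).
B = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)

A_fp8, A_inv_s = to_float8(A, torch.float8_e4m3fn)
B_fp8, B_inv_s = to_float8(B, torch.float8_e5m2)  # e4m3 x e5m2 is allowed; e5m2 x e5m2 is not

out = flashinfer.bmm_fp8(A_fp8, B_fp8, A_inv_s, B_inv_s, torch.bfloat16)
ref = torch.bmm(A, B)  # bf16 reference, compared with a loose tolerance
```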

pytest python/tests/test_bmm_fp8.py

Works on H100:

=================================================================================== test session starts ===================================================================================
platform linux -- Python 3.12.4, pytest-8.3.2, pluggy-1.5.0
rootdir: /flashinfer
collected 8 items

python/tests/test_bmm_fp8.py ...s...s                                                                                                                                                                       [100%]

============================================================================== 6 passed, 2 skipped in 2.16s ===============================================================================

@zhyncs added the enhancement (New feature or request) label on Aug 26, 2024
@zhyncs requested a review from yzh119 on August 26, 2024 18:04
@zhyncs self-assigned this on Aug 26, 2024
@yzh119 (Collaborator) commented Aug 26, 2024

Another suggestion is to move group gemm and bmm fp8 into a common gemm.py; we should also update group_gemm.rst (renaming it to gemm.rst) as well.

@zhyncs (Member, Author) commented Aug 26, 2024

> Another suggestion is to move group gemm and bmm fp8 into a common gemm.py; we should also update group_gemm.rst (renaming it to gemm.rst) as well.

Makes sense.

Review thread on python/flashinfer/gemm.py (outdated, resolved)
@yzh119 (Collaborator) left a review comment

LGTM, thanks for your contribution @zhyncs !

@yzh119 merged commit f1c0b68 into main on Aug 26, 2024
@zhyncs deleted the fp8-bmm-scale branch on August 26, 2024 19:32
yzh119 added a commit that referenced this pull request Aug 27, 2024
The documentation was not indexed properly in #469; this PR fixes the issue.
yzh119 added a commit that referenced this pull request Aug 27, 2024
🤖 I have created a release *beep* *boop*
---


## [0.1.6](v0.1.5...v0.1.6) (2024-08-27)

### SM75 Support

Starting from [0.1.6](v0.1.5...v0.1.6), our pre-built wheels include experimental support for sm75 (Turing architecture GPUs such as Tesla T4, Quadro RTX 6000 and RTX 2080).

### API Changes

#### `plan`/`run`

Starting from [0.1.6](v0.1.5...v0.1.6), the `begin_forward`/`forward`/`end_forward` APIs are replaced with the new `plan`/`run` APIs.
- `forward` is renamed to `run`, which is more precise and consistent
with the naming convention of cutlass's python API.
- `begin_forward` is renamed to `plan`, which is consistent with the
naming convention of nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There are some slight differences between the old `forward` and the new `run` APIs:
- All extra arguments such as `causal` and `logits_soft_cap` are now provided in the `plan` (previously `begin_forward`) API and cached until the next `plan` call; only the query and KV-Cache tensors need to be passed to `run`.

The old `begin_forward`/`forward`/`end_forward` APIs are still
functional, but we will gradually deprecate them in future releases.

Check [#466](#466) for
more details.
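
As a rough illustration of the migration, here is a hedged before/after sketch using the decode wrapper. The wrapper name and the `plan` argument list follow the pattern in the public docs, but the exact signature and the dummy page-table setup below are assumptions for illustration, not code taken from this release.

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_req = 4, 8
num_pages = batch_size * pages_per_req

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Dummy page table: each request owns `pages_per_req` full pages.
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_req
kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

# Old style (still functional, to be deprecated):
#   wrapper.begin_forward(kv_indptr, kv_indices, kv_last_page_len, ...)
#   out = wrapper.forward(q, kv_cache)
#   wrapper.end_forward()

# New style: extra arguments go into plan() and are cached until the next plan();
# run() only takes the query and the KV-Cache.
wrapper.plan(kv_indptr, kv_indices, kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size)

q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
kv_cache = torch.randn(num_pages, 2, page_size, num_kv_heads, head_dim,
                       device="cuda", dtype=torch.float16)
out = wrapper.run(q, kv_cache)
```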

#### `MultiLevelCascadeAttentionWrapper`

Starting from [0.1.6](v0.1.5...v0.1.6), we introduce a new `MultiLevelCascadeAttentionWrapper` API for cascade inference, which supports multi-level cascade inference where all levels' KV-Cache can be managed in a unified Paged KV-Cache.

See the [documentation](https://docs.flashinfer.ai/api/python/cascade.html#flashinfer.cascade.MultiLevelCascadeAttentionWrapper) and [tutorial](https://docs.flashinfer.ai/tutorials/kv_layout.html#multi-level-cascade-inference-data-layout) for API usage and a layout explanation.

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and
`BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated in
future releases.

### Features

* sm75 support
([#448](#448),
[#449](#449))
* add `MultiLevelCascadeAttentionWrapper` API
([#462](#462))
([1e37989](1e37989))
* add accept num, emit num metric for ChainSpeculativeSampling
([#450](#450))
([fa38b5e](fa38b5e))
* support bmm fp8
([#469](#469))
([f1c0b68](f1c0b68))

### Refactor

* refactor: replace `begin_forward`/`forward`/`end_forward` with
`plan`/`run`
[#466](#466)

### Misc

* misc: improve error handling of sampling kernels
([#456](#456))
([0dce178](0dce178))

### Performance Improvements

* slight optimization on f16->f8 fragment layout swizzling
([#453](#453))
([0d61871](0d61871))
* slight optimization on fragment layout swizzle
([#458](#458))
([7c397cb](7c397cb))
* use persistent kernel for merging attention states
([#459](#459))
([be6bf5b](be6bf5b))

### Acknowledgement

We thank [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU) for enhancing the speculative sampling operator, [@merrymercy](https://github.com/merrymercy) for API change suggestions, and [@zhyncs](https://github.com/zhyncs) for integrating the fp8 BMM cuBLAS implementation.

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>