feat: add trtllm moe_allreduce_fusion #1108
Please remove all usages of the packed/unpacked data types and use `vec_t` instead.
## 📌 Description

Update the create_ipc_buffer implementation. Add unit tests for create_ipc_buffer.

## 🔍 Related Issues

To help debug #1108.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
Next step: uncomment and complete the fused quantization. This may depend on #1142.
"hidden_dim * sizeof(T) must be a multiple of kBytesPerAccess");
if (params.residual_out && not params.norm_out && params.quant_out) {
// pattern1: AR+Add_RMS+Quant
// [m, 7168] bf16 allreduce_in, [m, 7168] bf16 residual_in
Do we have a shape check somewhere?
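One way to answer the shape-check question is a host-side validation that runs before any kernel is launched. The following is only a sketch under stated assumptions: `check_fusion_shapes` is a hypothetical helper, not code from this PR, and `kBytesPerAccess = 16` is an assumed value (the thread quotes the name but not its definition). It checks that `allreduce_in` and `residual_in` share the same `[m, d]` shape and reuses the alignment message quoted in the diff above.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Assumption: 16-byte (128-bit) vectorized accesses. The actual value of
// kBytesPerAccess is defined in the PR and is not shown in this thread.
constexpr std::size_t kBytesPerAccess = 16;

// Hypothetical helper: validates two [m, d] inputs before launching a fused kernel.
// T is the element type (use any 2-byte type as a stand-in for bf16).
template <typename T>
void check_fusion_shapes(std::int64_t allreduce_m, std::int64_t allreduce_d,
                         std::int64_t residual_m, std::int64_t residual_d) {
  // Both inputs must describe the same [m, d] tensor shape.
  if (allreduce_m != residual_m || allreduce_d != residual_d) {
    throw std::invalid_argument(
        "allreduce_in and residual_in must have the same [m, d] shape");
  }
  // The hidden dimension must be positive and its byte size must be a
  // whole number of vectorized accesses.
  if (allreduce_d <= 0 ||
      (static_cast<std::size_t>(allreduce_d) * sizeof(T)) % kBytesPerAccess != 0) {
    throw std::invalid_argument(
        "hidden_dim * sizeof(T) must be a multiple of kBytesPerAccess");
  }
}
```

With a 2-byte element type standing in for bf16, `check_fusion_shapes<std::uint16_t>(m, 7168, m, 7168)` passes because 7168 * 2 = 14336 bytes is a multiple of 16, while a hidden dimension of 7 would be rejected.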
torch.cuda.synchronize()

# 6. Check correctness
tolerance = 8e-2 if dtype == torch.float16 else 8e-1
8e-1 seems too large to me. Can you give an example of the distribution of `all_reduce_out`?
// [m, d] bf16 allreduce_in, [m, d] bf16 residual_in
// [m, d] bf16 residual_out, [m, d] bf16 norm_out, [m, d] fp4 quant_out

if (params.allreduce_out && params.residual_out && !params.norm_out && params.quant_out) {
The remaining part can still be dispatched:
DISPATCH_MOEREDUCTION_KERNEL(T, params, launch_with_pdl, ar, res, rms, quant)
I'm good with the PR, thanks so much for your contribution!
Please refer to 9c229c9 for how to simplify the macro.
A note on naming conventions: in flashinfer we usually write both the runtime variable and the constexpr in the macro definition, which makes it easier for developers to track which new constexprs the macro introduces:
#define DISPATCH_*(var, CONST_EXPR)
and we capitalize the CONST_EXPR.
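To make the convention concrete, here is a minimal sketch of a boolean dispatch macro written in the `DISPATCH_*(var, CONST_EXPR)` style. It is not the macro from commit 9c229c9 or from this PR: `DISPATCH_BOOL`, `Params`, `pattern_id`, and `dispatch_pattern` are hypothetical names chosen for illustration. Each call takes the runtime variable first and the capitalized constexpr name second, so a reader can see at the call site which compile-time constants are introduced. Nesting four calls covers every combination of the `allreduce_out`, `residual_out`, `norm_out`, and `quant_out` flags, which is one way the remaining patterns could still be dispatched.

```cpp
// Hypothetical macro: binds the runtime bool `var` to a constexpr bool named
// CONST_EXPR and runs the body once in the matching branch.
#define DISPATCH_BOOL(var, CONST_EXPR, ...) \
  do {                                      \
    if (var) {                              \
      constexpr bool CONST_EXPR = true;     \
      __VA_ARGS__                           \
    } else {                                \
      constexpr bool CONST_EXPR = false;    \
      __VA_ARGS__                           \
    }                                       \
  } while (0)

// Hypothetical stand-in for the kernel parameter struct: only the four
// output flags discussed in the review are modeled.
struct Params {
  bool allreduce_out;
  bool residual_out;
  bool norm_out;
  bool quant_out;
};

// Stand-in for a templated kernel launch: encodes the selected compile-time
// pattern as an integer so the dispatch can be observed.
template <bool AR, bool RES, bool RMS, bool QUANT>
int pattern_id() {
  return (AR ? 8 : 0) + (RES ? 4 : 0) + (RMS ? 2 : 0) + (QUANT ? 1 : 0);
}

// Maps the four runtime flags to one of the 16 template instantiations.
int dispatch_pattern(const Params& params) {
  int id = -1;
  DISPATCH_BOOL(params.allreduce_out, AR,
    DISPATCH_BOOL(params.residual_out, RES,
      DISPATCH_BOOL(params.norm_out, RMS,
        DISPATCH_BOOL(params.quant_out, QUANT,
          id = pattern_id<AR, RES, RMS, QUANT>();
        );
      );
    );
  );
  return id;
}
```

The trade-off of this design is that all 16 combinations are instantiated, which increases compile time and binary size; a real dispatcher may prefer to reject unsupported combinations explicitly.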
We try to add moe_all_reduce_fusion kernels from trt-llm. We split this PR into multiple ones. flashinfer-ai#1061 And all_reduce_fusion will be the next.

Co-authored-by: Zihao Ye <expye@outlook.com>

- Address the review comments.
- Allow d_qk without paged attention
- Update the cubins a bit more (Need to update the sha)
- Update the commit sha
- Updated the cubin loading path
- Address code review comments.
📌 Description
This PR adds the moe_all_reduce_fusion kernels from trt-llm.
🔍 Related Issues
We split this PR into multiple ones; see #1061.
all_reduce_fusion will be next.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
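✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
Reviewer Notes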