
Enabled Qwen2-MoE Tensor Parallelism (TP) inference #6551

Merged · 5 commits · Oct 9, 2024

Conversation

gyou2021 (Contributor)

Modified _replace_module in auto_tp.py:
The modification keeps the 'shared_expert_gate' and 'gate' layers in Qwen2-MoE as their original type, torch.nn.Linear, instead of converting them to LinearLayer. This way their weights are not split across multiple HPU/GPU cards, and Qwen2-MoE can run on multiple HPU/GPU cards. A conceptual sketch is shown below.
Since the weights of 'gate' are not split across cards, no all-gather operation is needed for it, which may improve performance.
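
A minimal sketch of the idea, assuming the module names used by the HuggingFace Qwen2-MoE implementation; the helper `should_shard` is illustrative, not the actual DeepSpeed code:

```python
import torch.nn as nn

# Routing layers that must stay replicated on every HPU/GPU card.
KEEP_AS_LINEAR = ("gate", "shared_expert_gate")

def should_shard(name: str, child: nn.Module) -> bool:
    """Return False for layers whose weights must stay whole on each rank."""
    if isinstance(child, nn.Linear) and name.split(".")[-1] in KEEP_AS_LINEAR:
        return False  # keep as torch.nn.Linear; no tensor-parallel split
    return True       # other linears may be replaced with a sharded LinearLayer
```

The routing gates are small relative to the expert weights, so replicating them on every rank costs little memory while avoiding the extra collectives a sharded gate would require.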

@delock (Collaborator) commented Sep 19, 2024

Hi @Yejing-Lai , do you want to provide some comments on this PR for Qwen2-MoE AutoTP support?

@Yejing-Lai (Contributor)

Could you try modifying this line to see if it meets your needs? https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/auto_tp.py#L336

@gyou2021 (Contributor, Author)

> Could you try modifying this line to see if it meets your needs? https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/auto_tp.py#L336

Yes. It can provide the same functionality and result if properly coded.

@gyou2021 (Contributor, Author)

> Could you try modifying this line to see if it meets your needs? https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/auto_tp.py#L336

Thank you for your comments.
I moved the Qwen2-MoE linear filter from _replace_module() to _replace() for uniform code management; both approaches have the same behavior and the same result. The qwen2-moe branch has been updated. A rough sketch of the filter is shown below.
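
A rough sketch of such a filter, assuming the module names from the HuggingFace Qwen2-MoE implementation; the function below is illustrative and not the verbatim code added to _replace():

```python
def keep_gate_unsharded(name: str) -> bool:
    """True for Qwen2-MoE routing gates, which _replace() would return
    unchanged instead of converting to a tensor-parallel LinearLayer."""
    return name.endswith(("mlp.gate", "mlp.shared_expert_gate"))

# Example: the routing gates are skipped, an expert projection is not.
assert keep_gate_unsharded("model.layers.0.mlp.gate")
assert keep_gate_unsharded("model.layers.0.mlp.shared_expert_gate")
assert not keep_gate_unsharded("model.layers.0.mlp.experts.0.gate_proj")
```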

@delock (Collaborator) commented Sep 27, 2024

Hi @gyou2021, can you also add Qwen2-MoE to this list? Some users will check this page for AutoTP supported models.
https://github.com/microsoft/DeepSpeed/blob/master/docs/_tutorials/automatic-tensor-parallelism.md#supported-models

@gyou2021 (Contributor, Author)

> Hi @gyou2021, can you also add Qwen2-MoE to this list? Some users will check this page for AutoTP supported models. https://github.com/microsoft/DeepSpeed/blob/master/docs/_tutorials/automatic-tensor-parallelism.md#supported-models

Added. Thank you for your comment.

@delock (Collaborator) commented Sep 29, 2024

Hi @tjruwase, this PR adds AutoTP support for Qwen2-MoE. @Yejing-Lai and I have reviewed this change. Thanks!

@loadams loadams requested a review from tjruwase as a code owner October 8, 2024 23:06
@loadams loadams added this pull request to the merge queue Oct 9, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 9, 2024
@loadams loadams added this pull request to the merge queue Oct 9, 2024
@loadams loadams removed this pull request from the merge queue due to a manual request Oct 9, 2024
@loadams loadams added this pull request to the merge queue Oct 9, 2024
Merged via the queue into deepspeedai:master with commit 474a328 Oct 9, 2024
12 checks passed