
reduce cpu host overhead when using moe #5578

Merged

Conversation

ranzhejiang
Contributor

@ranzhejiang ranzhejiang commented May 29, 2024

The operation `.to('cpu')` is not necessary for `exp_counts`, and it causes a device-to-host synchronization that hurts performance.
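The idea behind the change can be sketched as follows. This is a minimal, illustrative example, not the actual code from `deepspeed/moe/sharded_moe.py`; the names `expert_counts`, `top_idx`, and `num_experts` are assumptions for illustration.

```python
import torch

def expert_counts(top_idx: torch.Tensor, num_experts: int) -> torch.Tensor:
    # torch.bincount is queued on the tensor's device; no host
    # synchronization is triggered at this point.
    return torch.bincount(top_idx.flatten(), minlength=num_experts)

# Before the change (forces a sync point on CUDA devices):
#   exp_counts = counts.to('cpu')  # blocks the host until the kernel finishes
# After the change, the counts stay on the same device as top_idx, and a
# copy to the host happens only if a later consumer explicitly needs it.

idx = torch.tensor([0, 2, 2, 1])
print(expert_counts(idx, 4).tolist())  # [1, 1, 2, 0]
```

Avoiding the eager `.to('cpu')` lets the GPU stream keep running asynchronously instead of stalling the Python host thread on every MoE forward pass.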

@ranzhejiang ranzhejiang requested a review from awan-10 as a code owner May 29, 2024 04:01
@loadams loadams requested a review from tohtana May 31, 2024 22:15
Contributor

@tohtana tohtana left a comment


@ranzhejiang Thank you for your contribution! I have a few questions about your changes. Can you clarify them?

deepspeed/moe/sharded_moe.py (review comment, resolved)
@ranzhejiang ranzhejiang force-pushed the zhejiang/reduce_host_overhead_moe branch from e9e32f4 to d860d2c Compare June 11, 2024 03:32
@ranzhejiang
Contributor Author

Hi @tohtana, I have addressed the points you raised and retested this PR with Megatron-DeepSpeed on a GPU platform (8x A800). It runs well and the loss remains consistent with the original method. Could you please review it again? Thanks!

@ranzhejiang ranzhejiang force-pushed the zhejiang/reduce_host_overhead_moe branch from 686f511 to 23ec4a1 Compare August 16, 2024 03:58
@ranzhejiang ranzhejiang force-pushed the zhejiang/reduce_host_overhead_moe branch from 23ec4a1 to 1cb0efd Compare August 16, 2024 03:59
@ranzhejiang
Contributor Author

#5881 also adopts this approach to reduce CPU time.

@tohtana tohtana added this pull request to the merge queue Aug 21, 2024
Merged via the queue into deepspeedai:master with commit 7260890 Aug 21, 2024
11 checks passed