-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
[Auto Parallel] Add zero h1 pipeline scheduling for paddle #62865
Merged
Merged
Changes from all commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
81fa843
reconstruct_pipeline_scheduler_pass
AndSonder 09731e2
add pipeline_scheduler_pass into __all__
AndSonder 63355db
update __init__.py
AndSonder c7483fc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AndSonder 9d9b0a7
recover __init__.py
AndSonder e52301b
extract split matmul_grad_op to pass_utils
AndSonder b4ee57d
fix
AndSonder 11d241d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AndSonder c87f313
add paddle.distributed.passes.pipeline_scheduler_pass' to setup.py
AndSonder ed96dfc
add paddle.distributed.passes.pipeline_scheduler_pass' to setup.py.in
AndSonder 617d248
apply suggestions from code review
AndSonder ea00f08
update
AndSonder 37e8ca0
fix
AndSonder 8f4867e
change func name
AndSonder 2962d82
Merge branch 'reconstruct_pipeline_scheduler_pass' of https://github.…
AndSonder 001e40c
Merge branch 'split_matmul_grad_v2' of https://github.com/AndSonder/P…
AndSonder 0bd16fb
update
AndSonder 69a56b3
update
AndSonder e7d7f05
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AndSonder 29d3699
Merge branch 'reconstruct_pipeline_scheduler_pass' of https://github.…
AndSonder 6ba3ec5
add zero bubble pipeline
AndSonder fe26a06
fix bug
AndSonder e49889f
fix
AndSonder 3fb3235
update
AndSonder 4f17c40
fix error micro step id
AndSonder f405be1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AndSonder 92f31de
add zero bubble unittest
AndSonder 572a0b6
update comment
AndSonder 4911d03
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AndSonder 2c43ad4
merge from dev
AndSonder 7a6b150
Merge branch 'develop' into fit_zero_h1
AndSonder 6b70ad5
add zb to __init__.py
AndSonder 6acb7f9
fix
AndSonder b7f8abd
fix
AndSonder 8a4dad3
fix codestyle
AndSonder 5f0dd0c
add enable_send_recv_overlap
AndSonder 96b0895
fix
AndSonder File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
135 changes: 135 additions & 0 deletions
135
python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_zero_bubble.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
|
||
from paddle.base import core | ||
|
||
from ...utils.log_utils import get_logger | ||
from ..pass_base import register_pass | ||
from ..pass_utils import _program_for_zero_bubble, split_matmul_grad_to_matmul | ||
from .pipeline_pass_base import PipelinePassBase | ||
|
||
FORWARD = "forward" | ||
BACKWARD = "backward" | ||
OPT = "optimizer" | ||
|
||
logger = get_logger(logging.INFO) | ||
|
||
|
||
@register_pass("pipeline_scheduler_ZBH1") | ||
class PipelineZeroBubblePipelinePass(PipelinePassBase): | ||
def __init__(self): | ||
super().__init__() | ||
self.set_attr("enable_optimizer_post_validation", 0) | ||
|
||
def _create_job_list(self): | ||
num_micro_batches = self.get_attr("num_micro_batches") | ||
pp_stage = self.get_attr("pp_stage") | ||
pp_degree = self.get_attr("pp_degree") | ||
|
||
job_list = [] | ||
assert ( | ||
pp_degree <= num_micro_batches | ||
), "Num of micro batches should larger than or equal to pp degree." | ||
|
||
micro_batch_in_warmup = pp_degree - pp_stage | ||
micro_batch_in_zero_bubble = num_micro_batches - pp_degree | ||
|
||
forward_micro_batch_id = 0 | ||
for _ in range(micro_batch_in_warmup): | ||
forward_job = core.Job(FORWARD) | ||
forward_job.set_micro_batch_id(forward_micro_batch_id) | ||
job_list.append(forward_job) | ||
forward_micro_batch_id += 1 | ||
|
||
backward_micro_batch_id = 0 | ||
for _ in range(pp_stage): | ||
backward_b_job = core.Job(BACKWARD + '_b') | ||
backward_b_job.set_micro_batch_id(backward_micro_batch_id) | ||
job_list.append(backward_b_job) | ||
backward_micro_batch_id += 1 | ||
|
||
forward_job = core.Job(FORWARD) | ||
forward_job.set_micro_batch_id(forward_micro_batch_id) | ||
job_list.append(forward_job) | ||
forward_micro_batch_id += 1 | ||
|
||
for _ in range(micro_batch_in_zero_bubble): | ||
backward_job = core.Job(BACKWARD) | ||
backward_job.set_micro_batch_id(backward_micro_batch_id) | ||
job_list.append(backward_job) | ||
|
||
forward_job = core.Job(FORWARD) | ||
forward_job.set_micro_batch_id(forward_micro_batch_id) | ||
job_list.append(forward_job) | ||
|
||
forward_micro_batch_id += 1 | ||
backward_micro_batch_id += 1 | ||
|
||
for _ in range(micro_batch_in_warmup - 1): | ||
backward_job = core.Job(BACKWARD) | ||
backward_job.set_micro_batch_id(backward_micro_batch_id) | ||
job_list.append(backward_job) | ||
backward_micro_batch_id += 1 | ||
|
||
if pp_stage > 0: | ||
backward_b_job = core.Job(BACKWARD + '_b') | ||
backward_b_job.set_micro_batch_id(backward_micro_batch_id) | ||
job_list.append(backward_b_job) | ||
|
||
backward_w_job = core.Job(BACKWARD + '_w') | ||
backward_w_job.set_micro_batch_id(backward_micro_batch_id) | ||
job_list.append(backward_w_job) | ||
else: | ||
backward_job = core.Job(BACKWARD) | ||
backward_job.set_micro_batch_id(backward_micro_batch_id) | ||
job_list.append(backward_job) | ||
backward_micro_batch_id += 1 | ||
|
||
for i in range(pp_stage): | ||
backward_w_job = core.Job(BACKWARD + '_w') | ||
backward_w_job.set_micro_batch_id(i) | ||
job_list.append(backward_w_job) | ||
|
||
opt_job = core.Job(OPT) | ||
opt_job.set_micro_batch_id(0) | ||
job_list.append(opt_job) | ||
return job_list | ||
|
||
def _split_matmul_grad_ops_to_matmul(self, program, dist_context): | ||
for block in program.blocks: | ||
matmul_grad_op_idx = [] | ||
ops = block.ops | ||
for i, op_i in enumerate(ops): | ||
if ( | ||
op_i.type == "matmul_v2_grad" | ||
and not op_i.attr("trans_x") | ||
and not op_i.attr("trans_y") | ||
): | ||
matmul_grad_op_idx.append(i) | ||
|
||
for matmul_grad_id in reversed(matmul_grad_op_idx): | ||
split_matmul_grad_to_matmul( | ||
block, matmul_grad_id, dist_context=dist_context | ||
) | ||
|
||
def _partial_programs(self, program): | ||
dist_context = self.get_attr("dist_context") | ||
self._split_matmul_grad_ops_to_matmul(program, dist_context) | ||
enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap") | ||
types, sub_program_list = _program_for_zero_bubble( | ||
program, enable_send_recv_overlap | ||
) | ||
return types, sub_program_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deal the operators without
output
such assend_v2
,c_sync_calc_stream
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, 麻烦研发老师有空的时候再测试一下是否还会 hang 住 ~