[Auto Parallel] Add zero bubble H1 (ZBH1) pipeline scheduling for paddle (#62865)
* reconstruct_pipeline_scheduler_pass

* add pipeline_scheduler_pass into __all__

* update __init__.py

* recover __init__.py

* extract split matmul_grad_op to pass_utils

* fix

* add paddle.distributed.passes.pipeline_scheduler_pass' to setup.py

* add paddle.distributed.passes.pipeline_scheduler_pass' to setup.py.in

* apply suggestions from code review

* update

* fix

* change func name

* update

* update

* add zero bubble pipeline

* fix bug

* fix

* update

* fix error micro step id

* add zero bubble unittest

* update comment

* add zb to __init__.py

* fix

* fix

* fix codestyle

* add enable_send_recv_overlap

* fix
AndSonder authored Apr 18, 2024
1 parent cd050fc commit adf8689
Showing 7 changed files with 474 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/paddle/distributed/passes/__init__.py
@@ -102,6 +102,7 @@
Pipeline1F1BPass,
PipelineEager1F1BPass,
PipelineVirtualPipelinePass,
PipelineZeroBubblePipelinePass,
apply_pass,
)
from .ps_trainer_pass import ( # noqa: F401
141 changes: 139 additions & 2 deletions python/paddle/distributed/passes/pass_utils.py
@@ -276,10 +276,10 @@ def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
f"Skip gc vars for {job_type}-({micro_batch_id}): {skip_gc_vars}"
)

if job_type == "backward":
if job_type in ["backward", "backward_w"]:
assert (
len(skip_gc_vars) == 0
), f"When enabling pipeline parallelism strategy, the skip_gc_vars for backward subprogram must be empty, but it is {skip_gc_vars}."
), f"When enabling pipeline parallelism strategy, the skip_gc_vars for {job_type} subprogram must be empty, but it is {skip_gc_vars}."

job.set_skip_gc_vars(skip_gc_vars)
suffixed_required_vars[micro_batch_id] |= required_vars
@@ -778,6 +778,143 @@ def _split_ops(block):
return list(type_to_program.keys()), list(type_to_program.values())


def _get_backward_op_type(block, op):
# Ops with no outputs (e.g. 'send_v2') belong to backward_b; ops whose outputs
# are all parameter gradients belong to backward_w.
if len(op.output_arg_names) == 0:
return "backward_b"
for name in op.output_arg_names:
name = name.split("@")[0]
if not block._find_var_recursive(name):
return "backward_b"
var = block._find_var_recursive(name)
if not var.is_parameter:
return "backward_b"

return "backward_w"


def _program_for_zero_bubble(program, enable_send_recv_overlap=False):
if enable_send_recv_overlap:
_overlap_send_recv(program)
else:
_insert_sync_for_fthenb_1f1b(program)

oprole_type = {
0: "forward",
1: "backward",
2: "backward_b",
3: 'backward_w',
4: "optimizer",
}

def _split_ops(block):
# split the program based on the op_role
type_to_ops = OrderedDict()
for type in oprole_type.values():
type_to_ops[type] = []
type_to_ops["fetch"] = []

for op in block.ops:
if _is_fetch_op(op):
type_to_ops["fetch"].append(op)
elif is_forward_op(op):
type_to_ops["forward"].append(op)
elif is_backward_op(op):
type = _get_backward_op_type(block, op)
type_to_ops[type].append(op)
type_to_ops["backward"].append(op)
elif is_optimize_op(op):
type_to_ops["optimizer"].append(op)
else:
raise ValueError(
"The op role: "
+ str(op.attr('op_role'))
+ " isn't one of Forward, Backward or Optimizer."
)
return type_to_ops

type_to_program = OrderedDict()
for type in oprole_type.values():
type_to_program[type] = Program()

for idx, src_block in enumerate(program.blocks):
type_to_ops = _split_ops(src_block)
fwd_ops, bwd_ops, bwd_b_ops, bwd_w_ops, opt_ops, fetch_ops = (
type_to_ops["forward"],
type_to_ops["backward"],
type_to_ops["backward_b"],
type_to_ops["backward_w"],
type_to_ops["optimizer"],
type_to_ops["fetch"],
)
if idx == 0:
fwd_block = type_to_program["forward"].block(0)
_add_ops_into_block(src_block, fwd_block, fwd_ops)

bwd_block = type_to_program["backward"].block(0)
_add_ops_into_block(src_block, bwd_block, bwd_ops)

bwd_block_b = type_to_program["backward_b"].block(0)
_add_ops_into_block(src_block, bwd_block_b, bwd_b_ops)

bwd_block_w = type_to_program["backward_w"].block(0)
_add_ops_into_block(src_block, bwd_block_w, bwd_w_ops)

opt_block = type_to_program["optimizer"].block(0)
_add_ops_into_block(src_block, opt_block, opt_ops)
else:
if len(fwd_ops):
fwd_block = type_to_program["forward"]._create_block(
parent_idx=src_block.parent_idx
)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, fwd_block, fwd_ops)

if len(bwd_ops):
bwd_block = type_to_program["backward"]._create_block(
parent_idx=src_block.parent_idx
)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, bwd_block, bwd_ops)

if len(bwd_b_ops):
bwd_block_b = type_to_program["backward_b"]._create_block(
parent_idx=src_block.parent_idx
)
bwd_block_b._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, bwd_block_b, bwd_b_ops)

if len(bwd_w_ops):
bwd_block_w = type_to_program["backward_w"]._create_block(
parent_idx=src_block.parent_idx
)
bwd_block_w._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, bwd_block_w, bwd_w_ops)

if len(opt_ops):
opt_block = type_to_program["optimizer"]._create_block(
parent_idx=src_block.parent_idx
)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
_add_ops_into_block(src_block, opt_block, opt_ops)

for fetch_op in fetch_ops:
in_name = fetch_op.input_arg_names[0]
dst_block = None
for block in [fwd_block, bwd_block_b, bwd_block_w, opt_block]:
if block._find_var_recursive(in_name):
dst_block = block
break
if dst_block:
_create_program(src_block, dst_block, fetch_op)

for prog in type_to_program.values():
prog._sync_with_cpp()
prog._roll_to_global_block()

return list(type_to_program.keys()), list(type_to_program.values())
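Reading the construction above, the job types come back in the insertion order of oprole_type, so callers see the following (a minimal sketch; `program` is assumed to be an auto-parallel main Program whose ops already carry forward/backward/optimizer roles):

# Sketch only; `program` is an assumption, not taken from the diff.
types, sub_programs = _program_for_zero_bubble(program)
# types == ["forward", "backward", "backward_b", "backward_w", "optimizer"]
# sub_programs[i] contains only the ops of types[i]; fetch ops are appended to the
# forward / backward_b / backward_w / optimizer sub-program that owns the fetched var.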


def _add_event_dependency(recorder_op, waiter_op):
'''
Add the extra event dependency of the two operators.
python/paddle/distributed/passes/pipeline_scheduler_pass/__init__.py
Expand Up @@ -19,6 +19,7 @@
from .pipeline_eager_1f1b import PipelineEager1F1BPass # noqa: F401
from .pipeline_fthenb import PipelineFThenBPass # noqa: F401
from .pipeline_vpp import PipelineVirtualPipelinePass # noqa: F401
from .pipeline_zero_bubble import PipelineZeroBubblePipelinePass # noqa: F401

__all__ = []

@@ -29,7 +30,8 @@ def apply_pass(main_program, startup_program, pass_name, pass_attr={}):
"1F1B",
"Eager1F1B",
"VPP",
], f"pipeline scheduler only support FThenB, 1F1B, Eager1F1B and VPP, but receive {pass_name}"
"ZBH1",
], f"pipeline scheduler only support FThenB, 1F1B, Eager1F1B, VPP and ZBH1, but receive {pass_name}"

if pass_name == "1F1B":
# TODO(Ruibiao): Move FLAGS_1f1b_backward_forward_overlap and
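For reference, a hedged usage sketch of driving the new scheduler through apply_pass: the attribute keys mirror the get_attr calls in PipelineZeroBubblePipelinePass, while main_program, startup_program, and dist_context are assumed to come from an existing auto-parallel build step rather than from this diff.

# Sketch only; not part of the commit, values are illustrative.
from paddle.distributed.passes.pipeline_scheduler_pass import apply_pass

pass_attr = {
    "num_micro_batches": 8,             # must be >= pp_degree (asserted by the pass)
    "pp_degree": 4,                     # number of pipeline stages
    "pp_stage": 1,                      # index of the current stage
    "dist_context": dist_context,       # assumed: the auto-parallel distributed context
    "enable_send_recv_overlap": False,  # overlap p2p send/recv when True
}
apply_pass(main_program, startup_program, "ZBH1", pass_attr)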
135 changes: 135 additions & 0 deletions python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_zero_bubble.py
@@ -0,0 +1,135 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from paddle.base import core

from ...utils.log_utils import get_logger
from ..pass_base import register_pass
from ..pass_utils import _program_for_zero_bubble, split_matmul_grad_to_matmul
from .pipeline_pass_base import PipelinePassBase

FORWARD = "forward"
BACKWARD = "backward"
OPT = "optimizer"

logger = get_logger(logging.INFO)


@register_pass("pipeline_scheduler_ZBH1")
class PipelineZeroBubblePipelinePass(PipelinePassBase):
def __init__(self):
super().__init__()
self.set_attr("enable_optimizer_post_validation", 0)

def _create_job_list(self):
num_micro_batches = self.get_attr("num_micro_batches")
pp_stage = self.get_attr("pp_stage")
pp_degree = self.get_attr("pp_degree")

job_list = []
assert (
pp_degree <= num_micro_batches
), "Num of micro batches should larger than or equal to pp degree."

micro_batch_in_warmup = pp_degree - pp_stage
micro_batch_in_zero_bubble = num_micro_batches - pp_degree

forward_micro_batch_id = 0
for _ in range(micro_batch_in_warmup):
forward_job = core.Job(FORWARD)
forward_job.set_micro_batch_id(forward_micro_batch_id)
job_list.append(forward_job)
forward_micro_batch_id += 1

backward_micro_batch_id = 0
for _ in range(pp_stage):
backward_b_job = core.Job(BACKWARD + '_b')
backward_b_job.set_micro_batch_id(backward_micro_batch_id)
job_list.append(backward_b_job)
backward_micro_batch_id += 1

forward_job = core.Job(FORWARD)
forward_job.set_micro_batch_id(forward_micro_batch_id)
job_list.append(forward_job)
forward_micro_batch_id += 1

for _ in range(micro_batch_in_zero_bubble):
backward_job = core.Job(BACKWARD)
backward_job.set_micro_batch_id(backward_micro_batch_id)
job_list.append(backward_job)

forward_job = core.Job(FORWARD)
forward_job.set_micro_batch_id(forward_micro_batch_id)
job_list.append(forward_job)

forward_micro_batch_id += 1
backward_micro_batch_id += 1

for _ in range(micro_batch_in_warmup - 1):
backward_job = core.Job(BACKWARD)
backward_job.set_micro_batch_id(backward_micro_batch_id)
job_list.append(backward_job)
backward_micro_batch_id += 1

if pp_stage > 0:
backward_b_job = core.Job(BACKWARD + '_b')
backward_b_job.set_micro_batch_id(backward_micro_batch_id)
job_list.append(backward_b_job)

backward_w_job = core.Job(BACKWARD + '_w')
backward_w_job.set_micro_batch_id(backward_micro_batch_id)
job_list.append(backward_w_job)
else:
backward_job = core.Job(BACKWARD)
backward_job.set_micro_batch_id(backward_micro_batch_id)
job_list.append(backward_job)
backward_micro_batch_id += 1

for i in range(pp_stage):
backward_w_job = core.Job(BACKWARD + '_w')
backward_w_job.set_micro_batch_id(i)
job_list.append(backward_w_job)

opt_job = core.Job(OPT)
opt_job.set_micro_batch_id(0)
job_list.append(opt_job)
return job_list

def _split_matmul_grad_ops_to_matmul(self, program, dist_context):
for block in program.blocks:
matmul_grad_op_idx = []
ops = block.ops
for i, op_i in enumerate(ops):
if (
op_i.type == "matmul_v2_grad"
and not op_i.attr("trans_x")
and not op_i.attr("trans_y")
):
matmul_grad_op_idx.append(i)

for matmul_grad_id in reversed(matmul_grad_op_idx):
split_matmul_grad_to_matmul(
block, matmul_grad_id, dist_context=dist_context
)

def _partial_programs(self, program):
dist_context = self.get_attr("dist_context")
self._split_matmul_grad_ops_to_matmul(program, dist_context)
enable_send_recv_overlap = self.get_attr("enable_send_recv_overlap")
types, sub_program_list = _program_for_zero_bubble(
program, enable_send_recv_overlap
)
return types, sub_program_list
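To make the resulting schedule concrete, hand-tracing _create_job_list for a hypothetical pp_degree=4, num_micro_batches=4 run (and assuming the loop/branch nesting of the upstream file) gives the following per-stage job order:

# Illustrative hand-trace only, not emitted by the commit.
# F = forward, B = combined backward, Bb = backward_b, Bw = backward_w, (n) = micro batch id.
# stage 0: F(0) F(1) F(2) F(3)  B(0)  B(1)  B(2)  B(3)                          OPT
# stage 1: F(0) F(1) F(2) Bb(0) F(3)  B(1)  B(2)  Bb(3) Bw(3) Bw(0)             OPT
# stage 2: F(0) F(1) Bb(0) F(2) Bb(1) F(3)  B(2)  Bb(3) Bw(3) Bw(0) Bw(1)       OPT
# stage 3: F(0) Bb(0) F(1) Bb(1) F(2) Bb(2) F(3)  Bb(3) Bw(3) Bw(0) Bw(1) Bw(2) OPT
# Later stages split more backward passes into Bb/Bw so that weight-gradient work (Bw)
# fills what would otherwise be tail bubbles, which is the ZB-H1 idea.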
3 changes: 3 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -87,6 +87,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_pipeline_scheduler_vpp)
set_tests_properties(test_pipeline_scheduler_vpp
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pipeline_scheduler_zb MODULES test_pipeline_scheduler_zb)
set_tests_properties(test_pipeline_scheduler_zb
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_auto_tuner_compare MODULES test_auto_tuner_compare)
set_tests_properties(test_auto_tuner_compare
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)