[Auto Parallel] add align mode code for dp (PaddlePaddle#69941)
* add align mode code for dp and moe

* fix pp in grad allreduce before adamw

* add grad allreduce hack for static mode

* fix static acc scale
Wennie396 authored Dec 25, 2024
1 parent 6368f90 commit 3c2d8c7
Showing 3 changed files with 55 additions and 4 deletions.
51 changes: 48 additions & 3 deletions python/paddle/distributed/auto_parallel/api.py
@@ -1245,18 +1245,63 @@ def state_dict(self):
def _append_optimize_op(self, block, param_and_grad):
if (
in_auto_parallel_align_mode() # In align mode, we use enable_delay_scale_loss by default
and in_dygraph_mode()
and param_and_grad[1].is_dist()
):
placements = param_and_grad[1].placements
meshs = param_and_grad[1].process_mesh
grad = param_and_grad[1]
grad_mesh = grad.process_mesh

def get_mesh(pp_idx=0):
"""
Get the mesh for pipeline stage pp_idx.
"""
mesh = fleet.auto.get_mesh()
if "pp" in mesh.dim_names:
mesh = mesh.get_mesh_with_dim("pp", pp_idx)
return mesh

ipp = 0
global_mesh = fleet.auto.get_mesh()
if "pp" in global_mesh.dim_names:
pp_degree = global_mesh.get_dim_size("pp")
for i in range(pp_degree):
if meshs.process_ids == get_mesh(i).process_ids:
ipp = i
break

change_mesh = False
if any(
isinstance(placement, dist.Partial) for placement in placements
) and (
(meshs.process_ids == get_mesh(ipp).process_ids)
and (meshs.dim_names != get_mesh(ipp).dim_names)
):
change_mesh = True

if change_mesh:
grad = dist.auto_parallel.moe_utils._dist_reshape(
grad,
grad.shape,
get_mesh(ipp),
[
dist.Partial(dist.ReduceType.kRedSum),
dist.Partial(dist.ReduceType.kRedSum),
],
)
placements = grad.placements

for i in range(len(placements) - 1, -1, -1):
if isinstance(placements[i], dist.Partial):
placements[i] = dist.Replicate()
-            grad = dist.reshard(grad, meshs, placements)
-            grad /= self.gradient_accumulation_steps
+            grad = dist.reshard(grad, grad.process_mesh, placements)
+            if self.gradient_accumulation_steps > 1 and in_dygraph_mode():
+                grad /= self.gradient_accumulation_steps

if change_mesh:
grad = dist.auto_parallel.moe_utils._dist_reshape(
grad, grad.shape, grad_mesh, [dist.Replicate()]
)
param_and_grad = (param_and_grad[0], grad)
return self._inner_opt._append_optimize_op(block, param_and_grad)
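
For context, the dygraph branch above converts every Partial placement on the gradient to Replicate and reshards, which performs the data-parallel allreduce at optimizer time, and only then divides by the accumulation steps (align mode assumes enable_delay_scale_loss, so the loss itself is not pre-scaled). Below is a standalone sketch of that reshard-and-scale pattern; the mesh, the "dp" dim name and acc_steps are illustrative assumptions rather than values from this PR, and the snippet must be launched on two ranks with python -m paddle.distributed.launch.

import paddle
import paddle.distributed as dist

# Assumed 2-way data-parallel mesh and accumulation depth (illustrative only).
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
acc_steps = 4

# A gradient that is Partial(sum) over the dp axis: each rank holds a partial
# sum that still needs an allreduce.
local_grad = paddle.ones([2, 3])
grad = dist.shard_tensor(local_grad, mesh, [dist.Partial(dist.ReduceType.kRedSum)])

# Turn every Partial placement into Replicate; the reshard then issues the
# allreduce so all ranks end up with the summed gradient.
placements = [
    dist.Replicate() if isinstance(p, dist.Partial) else p
    for p in grad.placements
]
grad = dist.reshard(grad, grad.process_mesh, placements)

# With delayed loss scaling, the accumulated gradient is divided here instead
# of scaling the loss on every micro-step.
if acc_steps > 1:
    grad /= acc_steps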

4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/static/utils.py
@@ -2746,6 +2746,8 @@ def split_mesh(global_mesh: ProcessMesh, sub_mesh_dim: int):
)
sub_mesh_list = []
for sub_process_ids in splitted_process_ids:
-        sub_mesh_list.append(ProcessMesh(sub_process_ids))
+        sub_mesh_list.append(
+            ProcessMesh(sub_process_ids, global_mesh.dim_names)
+        )

return sub_mesh_list
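
The change above makes split_mesh keep the parent mesh's dim_names on every sub-mesh. That matters for the dygraph branch in api.py, which compares both process_ids and dim_names when matching a gradient's mesh to its pp-stage mesh, so an unnamed sub-mesh would never match. A small illustration with an assumed 2x2 mesh (not test code from this PR):

from paddle.distributed import ProcessMesh

global_mesh = ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])

sub_ids = [[0, 1]]  # process ids of one pipeline stage, kept 2-D for clarity
without_names = ProcessMesh(sub_ids)                      # default names, e.g. ['d0', 'd1']
with_names = ProcessMesh(sub_ids, global_mesh.dim_names)  # keeps ['pp', 'dp']

print(without_names.dim_names)
print(with_names.dim_names)  # ['pp', 'dp']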
4 changes: 4 additions & 0 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -280,6 +280,10 @@ def fused_allreduce_gradients(parameter_list, hcg):
group = sep_group if group is None else dp_sep_group

logger.debug("dp or sep start fuse allreduce gradients")
from paddle.distributed import in_auto_parallel_align_mode

if in_auto_parallel_align_mode():
scale = 1.0
fused_allreduce_gradients_with_group(parameter_list, group, scale=scale)
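
In align mode the hunk above forces the fused data-parallel allreduce to use scale=1.0, skipping the usual division of gradients by the group size so that gradient scaling matches the static-graph run being compared against. A minimal sketch of that gating follows; the 1.0 / dp_world_size default and the helper name are assumptions for illustration, not code from this file.

from paddle.distributed import in_auto_parallel_align_mode


def allreduce_scale(dp_world_size: int) -> float:
    # Hypothetical helper mirroring the added check: outside align mode the
    # allreduced gradients are averaged over the group, in align mode the
    # division is skipped and left to one agreed-upon place.
    scale = 1.0 / dp_world_size
    if in_auto_parallel_align_mode():
        scale = 1.0
    return scale


print(allreduce_scale(dp_world_size=8))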


