From 3c2d8c70906b92d145717428e5663f2639e4ad7d Mon Sep 17 00:00:00 2001
From: Wennie396 <44974020+Wennie396@users.noreply.github.com>
Date: Wed, 25 Dec 2024 11:47:12 +0800
Subject: [PATCH] [Auto Parallel]add align mode code for dp (#69941)

* add align mode code for dp and moe

* fix pp in grad allreduce before adamw

* add grad allreduce hack for static mode

* fix static acc scale
---
 .../paddle/distributed/auto_parallel/api.py   | 51 +++++++++++++++++--
 .../distributed/auto_parallel/static/utils.py |  4 +-
 .../fleet/utils/hybrid_parallel_util.py       |  4 ++
 3 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index ae8d48b2f90100..686c2a401df0e9 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -1245,18 +1245,63 @@ def state_dict(self):
     def _append_optimize_op(self, block, param_and_grad):
         if (
             in_auto_parallel_align_mode()  # In align mode, we use enable_delay_scale_loss by default
-            and in_dygraph_mode()
             and param_and_grad[1].is_dist()
         ):
             placements = param_and_grad[1].placements
             meshs = param_and_grad[1].process_mesh
             grad = param_and_grad[1]
+            grad_mesh = grad.process_mesh
+
+            def get_mesh(pp_idx=0):
+                """
+                Get the mesh for pp_idx.
+                """
+                mesh = fleet.auto.get_mesh()
+                if "pp" in mesh.dim_names:
+                    mesh = mesh.get_mesh_with_dim("pp", pp_idx)
+                return mesh
+
+            ipp = 0
+            global_mesh = fleet.auto.get_mesh()
+            if "pp" in global_mesh.dim_names:
+                pp_degree = global_mesh.get_dim_size("pp")
+                for i in range(pp_degree):
+                    if meshs.process_ids == get_mesh(i).process_ids:
+                        ipp = i
+                        break
+
+            change_mesh = False
+            if any(
+                isinstance(placement, dist.Partial) for placement in placements
+            ) and (
+                (meshs.process_ids == get_mesh(ipp).process_ids)
+                and (meshs.dim_names != get_mesh(ipp).dim_names)
+            ):
+                change_mesh = True
+
+            if change_mesh:
+                grad = dist.auto_parallel.moe_utils._dist_reshape(
+                    grad,
+                    grad.shape,
+                    get_mesh(ipp),
+                    [
+                        dist.Partial(dist.ReduceType.kRedSum),
+                        dist.Partial(dist.ReduceType.kRedSum),
+                    ],
+                )
+                placements = grad.placements
             for i in range(len(placements) - 1, -1, -1):
                 if isinstance(placements[i], dist.Partial):
                     placements[i] = dist.Replicate()
-            grad = dist.reshard(grad, meshs, placements)
-            grad /= self.gradient_accumulation_steps
+            grad = dist.reshard(grad, grad.process_mesh, placements)
+            if self.gradient_accumulation_steps > 1 and in_dygraph_mode():
+                grad /= self.gradient_accumulation_steps
+
+            if change_mesh:
+                grad = dist.auto_parallel.moe_utils._dist_reshape(
+                    grad, grad.shape, grad_mesh, [dist.Replicate()]
+                )
             param_and_grad = (param_and_grad[0], grad)
         return self._inner_opt._append_optimize_op(block, param_and_grad)
diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py
index b5597e50a1b124..ecbaf6405e7d67 100644
--- a/python/paddle/distributed/auto_parallel/static/utils.py
+++ b/python/paddle/distributed/auto_parallel/static/utils.py
@@ -2746,6 +2746,8 @@ def split_mesh(global_mesh: ProcessMesh, sub_mesh_dim: int):
     )
     sub_mesh_list = []
     for sub_process_ids in splitted_process_ids:
-        sub_mesh_list.append(ProcessMesh(sub_process_ids))
+        sub_mesh_list.append(
+            ProcessMesh(sub_process_ids, global_mesh.dim_names)
+        )
     return sub_mesh_list
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index fa02323f87111f..b887e5a3e85169 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -280,6 +280,10 @@ def fused_allreduce_gradients(parameter_list, hcg):
         group = sep_group if group is None else dp_sep_group

     logger.debug("dp or sep start fuse allreduce gradients")
+    from paddle.distributed import in_auto_parallel_align_mode
+
+    if in_auto_parallel_align_mode():
+        scale = 1.0
     fused_allreduce_gradients_with_group(parameter_list, group, scale=scale)
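
For reference, the central pattern in the api.py hunk is: rewrite every Partial placement to Replicate, reshard the gradient on its own mesh (which performs the pending all-reduce), and undo the delayed loss scaling by dividing once by the accumulation steps. The sketch below reproduces that pattern in isolation; it is illustrative only, not part of the patch. The helper name replicate_partial_grad, the two-card "dp" mesh, the tensor shape, and the assumption that dist.shard_tensor accepts a Partial placement here are all made up for the example; run it with python -m paddle.distributed.launch on two devices.

    # Minimal sketch (not part of the patch): rewrite Partial -> Replicate and
    # undo delayed loss scaling, mirroring the patched _append_optimize_op.
    # Assumes two devices, e.g.:
    #   python -m paddle.distributed.launch --devices=0,1 sketch.py
    import paddle
    import paddle.distributed as dist


    def replicate_partial_grad(grad, gradient_accumulation_steps):
        # Replace every Partial placement with Replicate, then reshard on the
        # gradient's own mesh; the reshard carries out the pending all-reduce.
        placements = list(grad.placements)
        for i in range(len(placements) - 1, -1, -1):
            if isinstance(placements[i], dist.Partial):
                placements[i] = dist.Replicate()
        grad = dist.reshard(grad, grad.process_mesh, placements)
        # In align mode the loss is not scaled per micro-batch, so the delayed
        # scaling is applied here instead.
        if gradient_accumulation_steps > 1:
            grad /= gradient_accumulation_steps
        return grad


    if __name__ == "__main__":
        mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
        # A gradient-like tensor that is Partial over the dp mesh: each rank
        # holds a partial sum that still needs an all-reduce.
        grad = dist.shard_tensor(
            paddle.ones([4, 8]),
            mesh,
            [dist.Partial(dist.ReduceType.kRedSum)],
        )
        grad = replicate_partial_grad(grad, gradient_accumulation_steps=4)
        print(grad.placements)  # expected: [Replicate()]

The hybrid_parallel_util.py hunk appears to be the other half of the same alignment: with scale = 1.0 the fused all-reduce sums gradients without the usual averaging, matching the Partial(kRedSum) semantics the auto-parallel path produces above.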