fix ds-sp grad scale for zero0
inkcherry committed May 21, 2024
1 parent 1d81967 commit cb15ffa
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions deepspeed/runtime/engine.py
@@ -2411,20 +2411,24 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
         else:
             dp_group = groups._get_sequence_data_parallel_group()
 
+        dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size)
         for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets):
             if sparse_bucket_tuple:
                 bucket_type, sparse_bucket = sparse_bucket_tuple
-                self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group)
+                self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size)
 
         for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets):
             if dense_bucket_tuple:
                 bucket_type, dense_bucket = dense_bucket_tuple
-                self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
+                self.allreduce_no_retain(dense_bucket,
+                                         dp_group=dp_group,
+                                         numel_per_bucket=elements_per_buffer,
+                                         dp_world_size=dp_world_size)
 
     def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
         # to maintain the gradients value unaffected by ep_size setting,
         # utilize dp_world_size for allreduce average
-        dp_world_size = dist.get_world_size(groups._get_data_parallel_group())
+        dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) / float(self.sequence_parallel_size)
         for ep_name, expert_grads_group in expert_grads.items():
             ep_dp_group = groups._get_expert_data_parallel_group(ep_name)
             split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(
@@ -2491,9 +2495,9 @@ def sparse_allreduce(self, sparse, dp_group, dp_world_size=None):
             dp_world_size = dist.get_world_size(group=dp_group)
         if self.postscale_gradients():
             if self.gradient_average:
-                values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size)))
+                values.mul_(self.gradient_predivide_factor() / (dp_world_size))
             else:
-                values.mul_(1. / (dp_world_size / float(self.sequence_parallel_size)))
+                values.mul_(1. / (dp_world_size))
 
         indices_device_list = self.sparse_all_gather(indices, dp_group)
         values_device_list = self.sparse_all_gather(values, dp_group)
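What the diff does: with DeepSpeed sequence parallelism (ds-sp), the reduction group spans dp_size x sequence_parallel_size ranks, but only the data-parallel replicas should be averaged; each sequence-parallel rank holds the gradient of a different sequence shard, and those shards must be summed. Previously the ZeRO stage-0 dense path (allreduce_no_retain) divided by the full group size, shrinking gradients by an extra factor of sequence_parallel_size. The commit computes the effective dp_world_size = group_size / sequence_parallel_size once in _reduce_non_expert_gradients (and likewise for expert gradients), threads it through both the sparse and dense paths, and drops the now-redundant division inside sparse_allreduce so the sparse path is not scaled twice. Below is a minimal sketch of the scaling arithmetic -- illustrative only, not DeepSpeed code; the 2 DP x 2 SP layout and the toy gradient values are assumptions:

    # Toy setup (assumed): 4 ranks = 2 data-parallel replicas x 2 sequence-parallel shards.
    group_world_size = 4          # size of the sequence-data-parallel allreduce group
    sequence_parallel_size = 2    # SP degree

    # Effective divisor: average over DP replicas only; SP shards are summed.
    dp_world_size = group_world_size / float(sequence_parallel_size)  # -> 2.0

    # Per-rank gradients, ordered (dp0,sp0), (dp0,sp1), (dp1,sp0), (dp1,sp1).
    # DP replicas see identical data, so they produce identical shard gradients.
    per_rank_grads = [1.0, 3.0, 1.0, 3.0]

    reduced = sum(per_rank_grads)           # what allreduce(SUM) yields: 8.0

    old_zero0 = reduced / group_world_size  # 2.0 -- too small by a factor of SP size
    fixed = reduced / dp_world_size         # 4.0 == 1.0 + 3.0, the full-sequence gradient
    print(old_zero0, fixed)

The expected gradient is the sum over the two sequence shards (1.0 + 3.0 = 4.0) averaged over the two identical DP replicas; dividing by the whole group size instead halved it, which is the ZeRO-0 scaling bug this commit fixes.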
