Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix memcpy issue on backward for zero-infinity #6670

Merged
merged 7 commits into from
Oct 31, 2024
13 changes: 11 additions & 2 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,15 @@ def _setup_for_real_optimizer(self):
self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

offset = 0
max_partition_numel = 0
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()
max_partition_numel = max(max_partition_numel, param.partition_numel())
if self.offload_optimizer:
self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
torch.empty(max_partition_numel, device=self.device))

def _link_all_hp_params(self):
for p in self.module.parameters():
Expand Down Expand Up @@ -1498,9 +1503,13 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
offload_fp32_gradients[i].append(grad_buffer.float())
offload_fp32_offsets[i].append(dest_offset)
else:
buffer_numel = grad_buffer.numel()
fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
0, dest_offset, grad_buffer.numel())
fp32_grad_tensor.copy_(grad_buffer)
0, dest_offset, buffer_numel)
self.pinned_grad_buffer[:buffer_numel].copy_(
grad_buffer.to(dtype=torch.float32, non_blocking=True))
get_accelerator().synchronize()
fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)

# free the gradient
if not get_accelerator().is_synchronized_device():
Expand Down
Loading