From 45f03f621418daa90c88e471cab91f7751af8cfa Mon Sep 17 00:00:00 2001
From: Xinyu Lian
Date: Sat, 26 Oct 2024 00:45:53 -0500
Subject: [PATCH 1/2] fix memcpy issue on backward for zero-infinity

---
 deepspeed/runtime/zero/stage3.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index e2c273fd913f..28e25fb836c3 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -1500,7 +1500,8 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                     else:
                         fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
                             0, dest_offset, grad_buffer.numel())
-                        fp32_grad_tensor.copy_(grad_buffer)
+                        fp32_grad_tensor.copy_(
+                            grad_buffer.to(dtype=torch.float32, device=self.device, non_blocking=True).pin_memory())

             # free the gradient
             if not get_accelerator().is_synchronized_device():

From 109bb6d16dd5da93ac064c2ec77b64f8399d865f Mon Sep 17 00:00:00 2001
From: Xinyu Lian
Date: Tue, 29 Oct 2024 03:23:17 -0500
Subject: [PATCH 2/2] fix: use a pre-pinned buffer for grad D2H copy

---
 deepspeed/runtime/zero/stage3.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 28e25fb836c3..65460eb72a2f 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -538,10 +538,15 @@ def _setup_for_real_optimizer(self):
             self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

         offset = 0
+        max_partition_numel = 0
         for param in all_params:
             self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
                 0, offset, param.partition_numel())
             offset += param.partition_numel()
+            max_partition_numel = max(max_partition_numel, param.partition_numel())
+        if self.offload_optimizer:
+            self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
+                torch.empty(max_partition_numel, device=self.device))

     def _link_all_hp_params(self):
         for p in self.module.parameters():
@@ -1498,10 +1503,13 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                         offload_fp32_gradients[i].append(grad_buffer.float())
                         offload_fp32_offsets[i].append(dest_offset)
                     else:
+                        buffer_numel = grad_buffer.numel()
                         fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
-                            0, dest_offset, grad_buffer.numel())
-                        fp32_grad_tensor.copy_(
-                            grad_buffer.to(dtype=torch.float32, device=self.device, non_blocking=True).pin_memory())
+                            0, dest_offset, buffer_numel)
+                        self.pinned_grad_buffer[:buffer_numel].copy_(
+                            grad_buffer.to(dtype=torch.float32, non_blocking=True))
+                        get_accelerator().synchronize()
+                        fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)

             # free the gradient
             if not get_accelerator().is_synchronized_device():
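
Note (not part of the patch): the two commits above move from pinning a fresh host tensor on every gradient copy (PATCH 1/2) to staging each device-to-host copy through a single pre-allocated pinned buffer sized to the largest partition (PATCH 2/2). The minimal standalone sketch below illustrates that pinned-staging pattern with plain PyTorch; the names make_pinned_buffer and stage_grad_to_fp32 are illustrative helpers, not DeepSpeed APIs, and the sync/copy ordering mirrors the second patch under the assumption of a CUDA accelerator.

    # Sketch only: stage a GPU gradient into a CPU fp32 tensor via one reusable pinned buffer.
    import torch

    def make_pinned_buffer(max_numel: int) -> torch.Tensor:
        # Allocate the staging buffer once, pinned when CUDA is present,
        # sized to the largest gradient partition that will pass through it.
        return torch.empty(max_numel, dtype=torch.float32,
                           pin_memory=torch.cuda.is_available())

    def stage_grad_to_fp32(grad_buffer: torch.Tensor, fp32_grad_tensor: torch.Tensor,
                           pinned_buffer: torch.Tensor) -> None:
        n = grad_buffer.numel()
        # Cast to fp32 on the device, then copy device-to-host into the pinned slice.
        pinned_buffer[:n].copy_(grad_buffer.to(dtype=torch.float32, non_blocking=True))
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # ensure the D2H copy landed before reusing the buffer
        # Final host-side copy into the destination fp32 gradient partition.
        fp32_grad_tensor.copy_(pinned_buffer[:n], non_blocking=True)

    if __name__ == "__main__":
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        grad = torch.randn(1024, dtype=torch.float16, device=dev)   # stand-in gradient partition
        dst = torch.zeros(1024, dtype=torch.float32)                 # stand-in CPU fp32 partition
        pinned = make_pinned_buffer(4096)
        stage_grad_to_fp32(grad, dst, pinned)
        assert torch.allclose(dst, grad.float().cpu())

Reusing one pinned buffer avoids repeated pin_memory() allocations on the backward path, at the cost of the explicit synchronize before the buffer is reused for the next partition.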