diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index e2c273fd913f..65460eb72a2f 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -538,10 +538,15 @@ def _setup_for_real_optimizer(self):
             self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)
 
         offset = 0
+        max_partition_numel = 0
         for param in all_params:
             self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
                 0, offset, param.partition_numel())
             offset += param.partition_numel()
+            max_partition_numel = max(max_partition_numel, param.partition_numel())
+        if self.offload_optimizer:
+            self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
+                torch.empty(max_partition_numel, device=self.device))
 
     def _link_all_hp_params(self):
         for p in self.module.parameters():
@@ -1498,9 +1503,13 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                         offload_fp32_gradients[i].append(grad_buffer.float())
                         offload_fp32_offsets[i].append(dest_offset)
                     else:
+                        buffer_numel = grad_buffer.numel()
                         fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
-                            0, dest_offset, grad_buffer.numel())
-                        fp32_grad_tensor.copy_(grad_buffer)
+                            0, dest_offset, buffer_numel)
+                        self.pinned_grad_buffer[:buffer_numel].copy_(
+                            grad_buffer.to(dtype=torch.float32, non_blocking=True))
+                        get_accelerator().synchronize()
+                        fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)
 
             # free the gradient
             if not get_accelerator().is_synchronized_device():
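
For context, the patch replaces a direct copy_ from the device-side gradient into the pageable, offloaded fp32 partition with a copy staged through a reusable pinned buffer sized for the largest partition: the gradient is cast to fp32 on the device, transferred device-to-host into pinned memory, and then copied host-to-host into the fp32 grad slice. Below is a minimal standalone sketch of that staging pattern, assuming a CUDA device; the tensor names (grad_buffer, fp32_cpu_grad, pinned) are illustrative, not the DeepSpeed objects.

    import torch

    numel = 1 << 20
    grad_buffer = torch.randn(numel, dtype=torch.bfloat16, device="cuda")  # device-side gradient partition
    fp32_cpu_grad = torch.zeros(numel, dtype=torch.float32, device="cpu")  # pageable fp32 partition on the host

    # Allocate the pinned staging buffer once and reuse it for every partition.
    pinned = torch.empty(numel, dtype=torch.float32, pin_memory=True)

    # Cast to fp32 on the device, copy device-to-host into pinned memory,
    # then finish with a host-to-host copy into the pageable fp32 slice.
    pinned[:numel].copy_(grad_buffer.to(dtype=torch.float32, non_blocking=True), non_blocking=True)
    torch.cuda.synchronize()  # ensure the asynchronous D2H transfer has landed
    fp32_cpu_grad.narrow(0, 0, numel).copy_(pinned[:numel])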