fix memcpy issue on backward for zero-infinity #6670
Conversation
Hi @xylian86, we have encountered significant performance degradation when applying this patch on the Gaudi HPU accelerator.
Below are the `partition_grads` profiles before the patch and with the patch applied (profile screenshots attached). The traces were taken for a small variation of BERT; for larger models, the performance degradation is more severe. Could you please share some details of the performance improvement that you observed (i.e. accelerator, workload, some stats)? Thanks
@deepcharm, thanks for reporting this degradation. It is strange, since this PR provided good speedups on CUDA, so I am guessing this is some device-specific effect. While waiting for @xylian86 to respond, I wonder if you could test a modification to L1508-1512 that removes the following pinned-buffer copy:

```python
self.pinned_grad_buffer[:buffer_numel].copy_(grad_buffer.to(dtype=torch.float32))
fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel])
```
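For reference, here is a minimal, self-contained sketch of the two paths being compared: the pinned-buffer staging introduced by this PR versus the direct copy suggested above. The tensor names are hypothetical stand-ins for the stage3.py objects, not the actual DeepSpeed code:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
numel = 1 << 20

# Hypothetical stand-ins for the stage3.py tensors discussed in this thread:
# a low-precision gradient on the accelerator and the offloaded FP32 gradient
# partition living in host memory.
grad_buffer = torch.randn(numel, dtype=torch.bfloat16, device=device)
fp32_grad_tensor = torch.empty(numel, dtype=torch.float32)

# Path added by this PR: upcast on the accelerator, stage through a pinned
# host buffer, then copy into the FP32 partition.
pinned_grad_buffer = torch.empty(numel, dtype=torch.float32,
                                 pin_memory=(device == "cuda"))
pinned_grad_buffer.copy_(grad_buffer.to(dtype=torch.float32))
fp32_grad_tensor.copy_(pinned_grad_buffer)

# Direct-copy alternative proposed above: upcast on the accelerator and copy
# straight into the FP32 partition, skipping the pinned staging buffer.
fp32_grad_tensor.copy_(grad_buffer.float())
```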
@deepcharm Thank you for reporting it. The setup on my side: 4 GH200 GPUs, GBS=128, MBS=16, ZeRO 3. Following this pull request, we observed a 4x reduction in backward pass time during the accumulation step, as demonstrated in the figure attached to the PR. As Tunji mentioned, it might be a device-specific effect; what is the bandwidth between CPU and GPU in your testing environment? This PR is particularly effective for newer GPU clusters with high-bandwidth interconnects (PCIe 5.0).
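To make the bandwidth question concrete, here is a rough way to estimate D2H copy bandwidth with plain PyTorch. This is a sketch only: CUDA is shown for simplicity (on Gaudi the equivalent would go through DeepSpeed's accelerator abstraction), and the sizes are arbitrary rather than taken from this thread:

```python
import time
import torch

def d2h_bandwidth_gb_s(numel: int = 1 << 28, pinned: bool = True) -> float:
    """Rough device-to-host copy bandwidth estimate in GB/s (CUDA only)."""
    src = torch.randn(numel, dtype=torch.float32, device="cuda")
    dst = torch.empty(numel, dtype=torch.float32, pin_memory=pinned)
    torch.cuda.synchronize()
    start = time.perf_counter()
    dst.copy_(src)
    torch.cuda.synchronize()
    return src.element_size() * numel / (time.perf_counter() - start) / 1e9

if torch.cuda.is_available():
    print(f"pinned D2H:   {d2h_bandwidth_gb_s(pinned=True):.1f} GB/s")
    print(f"pageable D2H: {d2h_bandwidth_gb_s(pinned=False):.1f} GB/s")
```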
@xylian86 which model did you use?
@nelyahu LLaMA 7B
@xylian86 Thanks for providing the details! Indeed, it seems that it's a device-specific effect. @tjruwase We've tried making the first copy non-blocking while also removing the device synchronize. Can we add a zero config option to enable/disable this feature?
Hi @tjruwase, could you please approve or comment on adding a config option to enable/disable this feature? Many thanks
@deepcharm, apologies for the delayed response. Yes, we will add a ds_config option to control this feature.
@deepcharm, do you know if ZeRO-1/2 on HPU experienced similar perf degradation with the earlier PR #5301?
@xylian86, did you try |
@tjruwase We will check ZeRO-1/2 + offload-optimizer. We observed it only in our nightly regression for ZeRO-Infinity and ZeRO-3 + offload-optimizer.
@tjruwase Hi Tunji, I tried both. Here are the benchmark results using one GH200 GPU with batch size=1, ZeRO 3, and a 3B model.
Before this PR:
After this PR:
Legend:
I agree that we should add a config option to ensure compatibility across different hardware.
@xylian86, thanks for the details. However, the results are confusing because if […]:

```python
fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
    0, dest_offset, grad_buffer.numel())
fp32_grad_tensor.copy_(grad_buffer.float())
```
As a follow-up to the previous post, a similar issue in the optimizer step occurs in the following two lines: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L2075-L2076.
When the CPU-GPU bandwidth is high, it would be more efficient to first copy the FP32 tensor to the GPU and then downcast it to FP16.
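A minimal sketch of that idea, using hypothetical stand-in tensors rather than the actual stage3.py code (CUDA shown for concreteness):

```python
import torch

if torch.cuda.is_available():
    numel = 1 << 20
    # Hypothetical stand-ins: updated FP32 parameters living on the host and
    # the low-precision parameter partition living on the accelerator.
    fp32_params_cpu = torch.randn(numel, dtype=torch.float32)
    fp16_partition = torch.empty(numel, dtype=torch.float16, device="cuda")

    # Implicit downcast: one copy_ with mismatched dtypes, where the dtype
    # conversion is entangled with the H2D transfer.
    fp16_partition.copy_(fp32_params_cpu)

    # Explicit downcast suggested above: move the FP32 data across the link
    # first, then convert to FP16 on the accelerator.
    fp16_partition.copy_(fp32_params_cpu.to("cuda").half())
```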
@xylian86, thanks for the update. The impact of explicit upcast on BWD is quite dramatic. Can you please confirm that the following results are correct?
@tjruwase Yes, the results are correct.
@xylian86, thanks for the confirmation. Given the HPU perf regression, we urgently need a PR that switches to an explicit GPU upcast and undoes this one. Are you able to help with this?
@tjruwase For sure. I will open a PR now.
@xylian86, thanks for the quick action. The PR looks great. @deepcharm, after more investigation we are replacing this PR with #6962. Can you please check for any perf issues with the new PR? Thanks!
Following the discussion in [PR-6670](#6670), the explicit upcast is much more efficient than the implicit upcast, so this PR replaces the implicit upcast with an explicit one. The results on a 3B model are shown below:

| Option | BWD (ms) | Speedup |
|----------------|----------|---------|
| Before PR-6670 | 25603.30 | 1x |
| After PR-6670 | 1174.31 | 21.8x |
| After this PR | 309.2 | 82.8x |
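For clarity, here is a toy illustration of the two variants being compared (hypothetical tensors, not the DeepSpeed source): the implicit form folds the BF16-to-FP32 conversion into the cross-device copy, while the explicit form converts on the accelerator first so the D2H transfer is a plain FP32 copy.

```python
import torch

if torch.cuda.is_available():
    grad_buffer = torch.randn(1 << 20, dtype=torch.bfloat16, device="cuda")
    fp32_grad = torch.empty(1 << 20, dtype=torch.float32)  # offloaded partition (host)

    # Implicit upcast: dtype conversion happens as part of the cross-device copy.
    fp32_grad.copy_(grad_buffer)

    # Explicit upcast: convert to FP32 on the accelerator, then do a
    # same-dtype D2H copy.
    fp32_grad.copy_(grad_buffer.float())
```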
This PR is similar to PR #5301, which optimizes the D2H time using pinned memory.
Previously, the D2H memcpy was the bottleneck during the final backward pass of each iteration for ZeRO-Infinity (offload), as shown in Trace-1. The new version eliminates the bottleneck, as shown in Trace-2.
Trace-1 (profile screenshot): the D2H memcpy dominating the final backward pass before this PR.
Trace-2 (profile screenshot): the same region after this PR, with the memcpy bottleneck eliminated.
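A minimal, self-contained sketch of the pinned-memory staging pattern this PR describes, including the asynchronous first copy discussed earlier in the thread. The tensor names are hypothetical stand-ins, not the actual stage3.py code:

```python
import torch

if torch.cuda.is_available():
    numel = 1 << 20
    # Stand-ins: a low-precision gradient on the GPU and the offloaded FP32
    # gradient partition in host memory.
    grad_buffer = torch.randn(numel, dtype=torch.bfloat16, device="cuda")
    fp32_grad_partition = torch.empty(numel, dtype=torch.float32)

    # Reusable page-locked staging buffer: D2H copies into pinned memory can be
    # DMA-driven and issued non-blocking.
    pinned_grad_buffer = torch.empty(numel, dtype=torch.float32, pin_memory=True)

    # Upcast on the GPU, copy asynchronously into the pinned buffer, wait for
    # the transfer, then move the staged data into the FP32 partition.
    pinned_grad_buffer.copy_(grad_buffer.to(torch.float32), non_blocking=True)
    torch.cuda.synchronize()
    fp32_grad_partition.copy_(pinned_grad_buffer)
```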
cc @tjruwase