[BUG] Crash with a minimal ZeRO stage 3 NVMe checkpointing example #4565
Forgot to comment: the minimal example causes the engine to go through this code path:

```python
if allreduce_gradients and self.enable_backward_allreduce:
    # Traditional code path that allreduces the module parameter grads
    self.allreduce_gradients()
```

Here it erases the gradients without offloading them, as described above.
@eisene, thanks for reporting this issue. I believe you have correctly identified the buggy re-initialization of the two dicts (`deepspeed/runtime/zero/stage3.py`, lines 1379 to 1380 at a3926bb).

These dicts are correctly initialized at the beginning of the function (`deepspeed/runtime/zero/stage3.py`, lines 1345 to 1346 at a3926bb).

Are you able to submit a PR deleting the re-initializations? Thanks!
Done, see #4702
As discussed in microsoft#4565 with @tjruwase. Fix microsoft#4565. Fix microsoft#4696.
Describe the bug
The simplest possible training code with ZeRO stage 3 with NVMe offload for the optimizer crashes on `model.step()` with an error.

To Reproduce
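The reporter's exact script was not captured here; the following is a minimal sketch of the setup described (ZeRO stage 3 with NVMe optimizer offload, stepping once). The `nvme_path`, model sizes, and hyperparameters are illustrative assumptions, not taken from the original report.

```python
# Minimal sketch, NOT the reporter's exact script: a tiny model trained
# for one step with ZeRO stage 3 and NVMe optimizer offload.
import torch
import deepspeed

ds_config = {
    "train_batch_size": 1,
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": "/local_nvme",  # assumed scratch directory
        },
    },
}

net = torch.nn.Linear(1024, 1024)
model, _, _, _ = deepspeed.initialize(model=net,
                                      model_parameters=net.parameters(),
                                      config=ds_config)

x = torch.randn(1, 1024, dtype=torch.half, device=model.device)
loss = model(x).float().sum()
model.backward(loss)
model.step()  # the reported crash happens here
```

A script like this would be launched with the `deepspeed` launcher, e.g. `deepspeed repro.py`.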
Expected behavior
This script should exit with no error.
ds_report output
System info (please complete the following information):
Environment:
Build command:
Launcher context
deepspeed launcher

Docker context
No docker.
Additional context
It seems that the problem is being caused by the following two lines, 1334-1335 in `deepspeed/runtime/zero/stage3.py`, in `DeepSpeedZeroOptimizer_Stage3.partition_grads`. This resets the dictionary of offloaded gradients so that, later in the same function, lines 1357-1361 do nothing.
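The code snippets from the original report did not survive extraction. The following is an abbreviated paraphrase of the pattern in `partition_grads` (dict names taken from the linked file; surrounding code, including the loop variables `i`, `grad_buffer`, and `dest_offset`, is elided), with the offending lines marked BUG:

```python
def partition_grads(self, params_to_release, grad_partitions):
    # Correct: the dicts are initialized once at the top of the function
    offload_fp32_gradients = {}
    offload_fp32_offsets = {}

    for param, grad_partition in zip(params_to_release, grad_partitions):
        ...
        # Gradients destined for NVMe are accumulated per subgroup i
        offload_fp32_gradients.setdefault(i, []).append(grad_buffer.float())
        offload_fp32_offsets.setdefault(i, []).append(dest_offset)

    # BUG: duplicate re-initialization discards everything collected above
    offload_fp32_gradients = {}
    offload_fp32_offsets = {}

    # Later in the same function: swap the accumulated fp32 gradients out
    # to NVMe. With the dicts just reset, this loop iterates over nothing,
    # so the gradients are erased without ever being offloaded.
    if self.offload_optimizer and self.swap_optimizer:
        for i in offload_fp32_gradients.keys():
            self.optimizer_swapper.swap_out_gradients(
                parameter=self.fp32_partitioned_groups_flat[i],
                gradient_offsets=offload_fp32_offsets[i],
                gradient_tensors=offload_fp32_gradients[i])
```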
Commenting out the lines marked BUG causes the script to work as expected.