Fix RNG reload in resume training from epoch checkpoint #17055
Conversation
The documentation is not available anymore as the PR was closed or merged.
tests/trainer/test_trainer.py
Outdated
# For more than 1 GPUs, since the randomness is introduced in the model and with DataParallel (which is used
# in this test for more than 2 GPUs), the calls to the torch RNG will happen in a random order (sometimes
# GPU 0 will call first and sometimes GPU 1).
random_torch = torch.cuda.is_available() and torch.cuda.device_count() >= 1
Sorry, just a question regarding this line. AFAICT `random_torch` would only be `True` if at least one GPU is available. But this would mean this test case will not cover torch randomness when using the CPU. The unit test before this commit, however, did test randomness on the CPU, or at least was able to if no GPU was available. Is this change intended?
Good catch! I'll fix this :-)
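As a hedged sketch of the kind of adjustment discussed above (not necessarily the exact change that was merged), the flag could be flipped so it stays enabled on CPU-only machines and is only disabled when more than one GPU would let DataParallel interleave the RNG calls:

```python
import torch

# Hypothetical sketch: keep torch randomness in the test unless more than one
# GPU is available, since DataParallel would then have the per-GPU replicas
# consume the global torch RNG in a nondeterministic order.
random_torch = not torch.cuda.is_available() or torch.cuda.device_count() <= 1
```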
LGTM, thanks @sgugger!
…17055) * Fix RNG reload in resume training from epoch checkpoint * Fix test
What does this PR do?
This PR fixes reproducibility in training when checkpoints are saved every epoch. The main reason it was failing (as pointed out in #17032) is that the RNG states were never reloaded. They need to be reloaded right before iterating through the new epoch, since starting that iteration changes the global PyTorch RNG state (even if the dataloader uses its own generator). The new test added makes sure this reproducibility is fully covered.
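To illustrate the mechanism described above, here is a minimal, hedged sketch of saving the global RNG states alongside an epoch checkpoint and restoring them just before iterating through the resumed epoch. The function and file names are hypothetical and not the Trainer's actual API:

```python
import os
import random

import numpy as np
import torch


def save_rng_state(checkpoint_dir):
    """Capture the global RNG states next to an epoch checkpoint."""
    rng_state = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        rng_state["cuda"] = torch.cuda.random.get_rng_state_all()
    torch.save(rng_state, os.path.join(checkpoint_dir, "rng_state.pth"))


def load_rng_state(checkpoint_dir):
    """Restore the global RNG states.

    Call this right before iterating through the resumed epoch, since that
    iteration advances the global PyTorch RNG even when the dataloader uses
    its own generator.
    """
    rng_state = torch.load(os.path.join(checkpoint_dir, "rng_state.pth"))
    random.setstate(rng_state["python"])
    np.random.set_state(rng_state["numpy"])
    torch.random.set_rng_state(rng_state["cpu"])
    if torch.cuda.is_available() and "cuda" in rng_state:
        torch.cuda.random.set_rng_state_all(rng_state["cuda"])
```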
While debugging this, two issues came up, which this PR also fixes. One of them is that the test was flaky on multiple GPUs because of DataParallel (an issue that wouldn't be the case with DistributedDataParallel, but we would need to execute the test via a launcher in that case). So in the test, we only do PyTorch randomness on one or zero GPU to fix this flakiness (a toy illustration follows below).

Fixes #17032
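To make the flakiness concrete, here is a hedged toy example (not the repository's actual test model): when the model itself draws from the torch RNG in its forward pass, DataParallel replicates that forward on each GPU, and the order in which the replicas hit the RNG is not deterministic, so two otherwise identical runs can diverge.

```python
import torch
from torch import nn


class NoisyRegression(nn.Module):
    """Toy model whose forward pass consumes the torch RNG when enabled."""

    def __init__(self, random_torch=True):
        super().__init__()
        self.linear = nn.Linear(1, 1)
        self.random_torch = random_torch

    def forward(self, x):
        y = self.linear(x)
        if self.random_torch:
            # Each call advances the torch RNG; under DataParallel the
            # per-GPU replicas reach this line in an unpredictable order,
            # which breaks bit-for-bit reproducibility across runs.
            y = y + torch.randn_like(y) * 0.01
        return y
```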