Fix: training resume from fp16 for SDXL Consistency Distillation #6840
Conversation
Thanks for your contribution. Could we please reduce the number of steps to quickly check this?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Do you mean reducing the number of training steps in the test command?
Yeah. I am sure the above command will work, but it will take longer to validate. The fix can essentially be tested with far fewer steps, as I did here. Also, please run the code quality linter (see the sketch below).
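The exact linter invocation isn't reproduced above; in a diffusers checkout the workflow is normally driven by Makefile targets, roughly as sketched here. The `quality` extra and target names are assumptions based on the repository's standard tooling.

```bash
# Sketch only: run from the root of a diffusers checkout.
pip install -e ".[quality]"  # install the pinned formatting/linting tools (assumed extra name)
make style                   # auto-format the changed files
make quality                 # run the lint/style checks that CI enforces
```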
I've pushed a commit for the code style, and I've edited the test command to save a checkpoint after fewer steps, as per your example.
I don't think I made myself super clear, sorry. Your example commands still use a large number of steps. Can we please keep the example commands super lean? I hope I made myself clearer this time.
Yeah, makes sense! I've made the changes so that the training command is leaner and more efficient.
I just tried it and I am facing shape mismatch problems while running the second command.
Update: I just got a Runpod machine with higher VRAM for testing. I was able to test the updated script with lower settings. I also made some changes to the code and it works on my local machine. Please test it out and let me know.
Thank you! But now I am seeing some state dict key mismatches when resuming training. Could you look into that?
Pushed a new commit to fix the missing keys issue. Please have a look and let me know.
@@ -305,7 +310,7 @@ def parse_args():
     parser.add_argument(
         "--cache_dir",
         type=str,
-        default=None,
+        default="/workspace/cache",
We don't need to change this.
Forgot to remove it after local testing. Reverted it to None and pushed a commit. Sorry for the inconvenience.
Thanks so much for iterating. I have checked and it works: https://colab.research.google.com/gist/sayakpaul/fd2e863c9911031ad01fa9cf6863a5da/scratchpad.ipynb. Will merge the PR once the test suite passes.
Thank you for the consistent feedback. Happy to contribute to HuggingFace.
…gingface#6840)

* Fix: training resume from fp16 for lcm distill lora sdxl
* Fix coding quality - run linter
* Fix 1 - shift mixed precision cast before optimizer
* Fix 2 - State dict errors by removing load_lora_into_unet
* Update train_lcm_distill_lora_sdxl.py - Revert default cache dir to None

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
What does this PR do?
Part of #6552 for SDXL Consistency Distillation
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
What to test?
First run this:
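The exact test command isn't reproduced above; below is a deliberately lean sketch of what a short first run could look like. The script path matches the file touched by this PR, but every flag value (teacher model, dataset placeholder, step counts, output directory) is an illustrative assumption and should be adapted to your setup.

```bash
# Sketch only: flag values are assumptions chosen to keep the test run short.
accelerate launch examples/consistency_distillation/train_lcm_distill_lora_sdxl.py \
  --pretrained_teacher_model="stabilityai/stable-diffusion-xl-base-1.0" \
  --dataset_name="<your-dataset>" \
  --mixed_precision="fp16" \
  --resolution=1024 \
  --train_batch_size=1 \
  --max_train_steps=10 \
  --checkpointing_steps=5 \
  --output_dir="lcm-lora-sdxl-test"
```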
To resume training, run the following command:
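Again a sketch rather than the original command: it reuses the assumed flags from the first run and adds `--resume_from_checkpoint` so training picks up from the fp16 checkpoint saved above.

```bash
# Sketch only: resumes from the latest checkpoint in --output_dir and trains a few more steps.
accelerate launch examples/consistency_distillation/train_lcm_distill_lora_sdxl.py \
  --pretrained_teacher_model="stabilityai/stable-diffusion-xl-base-1.0" \
  --dataset_name="<your-dataset>" \
  --mixed_precision="fp16" \
  --resolution=1024 \
  --train_batch_size=1 \
  --max_train_steps=20 \
  --checkpointing_steps=5 \
  --output_dir="lcm-lora-sdxl-test" \
  --resume_from_checkpoint="latest"
```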
Who can review?
@sayakpaul