[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) #6514
Conversation
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)
We do this just before passing the parameters to the optimizer, to avoid any unintended side effects.
In a follow-up PR, I can wrap this utility into a function and move it to training_utils.py. It's shared by a number of scripts.
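As a rough sketch of what such a shared helper in training_utils.py could look like (the name `cast_training_params` and its exact signature are assumptions here, not something defined in this PR):

```python
# Hypothetical helper (name and signature are assumptions): upcast only the trainable
# (LoRA) parameters of the given models to a target dtype, leaving the frozen base
# weights in their original precision.
from typing import List, Union

import torch


def cast_training_params(
    models: Union[torch.nn.Module, List[torch.nn.Module]],
    dtype: torch.dtype = torch.float32,
) -> None:
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            # Only upcast trainable parameters (e.g. LoRA) into the target dtype.
            if param.requires_grad:
                param.data = param.to(dtype)
```

The training script would then call this on the LoRA-equipped models right before the optimizer is created, e.g. `cast_training_params([unet, text_encoder_one, text_encoder_two], dtype=torch.float32)`.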
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
    text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
We cannot use `load_lora_into_unet()` and `load_lora_into_text_encoder()` here, for the following reason (described only for the UNet but applicable to the text encoders, too):

- We call `add_adapter()` once on the UNet at the beginning of training. This creates an adapter config inside the UNet.
- Then, when loading an intermediate checkpoint in the accelerate hook, we call `load_lora_into_unet()`. It internally calls `inject_adapter_in_model()` again, with the config inferred from the provided state dict. So it internally creates another adapter. This is undesirable, right?
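As an illustration of the alternative (a sketch, not the exact code from this PR): because the adapter already exists inside the UNet, the checkpoint's LoRA weights can be written into it directly with peft's `set_peft_model_state_dict` instead of re-injecting an adapter. The helper name and the `"unet."` key prefix below are assumptions.

```python
# Sketch only: load LoRA weights into a UNet whose adapter was already created via
# add_adapter(), avoiding a second call to inject_adapter_in_model().
from peft import set_peft_model_state_dict


def load_unet_lora_from_checkpoint_(unet_, lora_state_dict):
    # Assumption: UNet keys in the combined state dict carry a "unet." prefix.
    unet_state_dict = {
        k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")
    }
    # Depending on how the checkpoint was serialized, these keys may also need to be
    # converted to peft's naming scheme before this call (diffusers ships conversion
    # utilities for that).
    set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
```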
@@ -996,17 +997,6 @@ def main(args):
    text_encoder_one.add_adapter(text_lora_config)
    text_encoder_two.add_adapter(text_lora_config)

    # Make sure the trainable params are in float32.
Thanks Sayak, LGTM.
Looking great! Thanks for all your work on this! I left one comment, wdyt?
@BenjaminBossan pinging for #6514 (comment).
[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) (huggingface#6514)

* fix: training resume from fp16.
* add: comment
* remove residue from another branch.
* remove more residues.
* thanks to Younes; no hacks.
* style.
* clean things a bit and modularize _set_state_dict_into_text_encoder
* add comment about the fix detailed.
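The `_set_state_dict_into_text_encoder` helper mentioned in the commit list above is not shown in this excerpt. Here is a minimal sketch of what such a helper might look like, following the same idea as the UNet snippet earlier (prefix-based filtering plus peft's `set_peft_model_state_dict`); the actual implementation may differ.

```python
# Sketch of a prefix-based helper along the lines of _set_state_dict_into_text_encoder;
# the real implementation in training_utils.py may differ.
from peft import set_peft_model_state_dict


def _set_state_dict_into_text_encoder(lora_state_dict, prefix, text_encoder):
    # Keep only the keys belonging to this text encoder and strip the prefix.
    text_encoder_state_dict = {
        k.replace(prefix, ""): v for k, v in lora_state_dict.items() if k.startswith(prefix)
    }
    # As above, the keys may first need converting to peft's naming scheme.
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
```

For SDXL this would presumably be called once per text encoder, e.g. with `prefix="text_encoder."` and `prefix="text_encoder_2."`.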
What does this PR do?
Tries to solve issues like #6442 in a clean way. Limits the changes to the DreamBooth SDXL LoRA script for now.
To test
First run
And then resume:
Tested with the train_text_encoder flag on as well. I would appreciate it if one of the reviewers could cross-check this.