fix: unscale fp16 gradient problem & potential error (#6086) #6231
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
# Make sure vae.dtype is consistent with the unet.dtype
if args.mixed_precision == "fp16":
    vae.to(weight_dtype)
```
This is not needed in my opinion. We already set the torch_dtype
in the pipeline when loading it.
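For reference, a minimal sketch of the loading pattern being referred to (argument names follow the training script; this is illustrative, not the exact code):

```python
from diffusers import StableDiffusionXLPipeline

# torch_dtype casts components loaded from the checkpoint to fp16;
# a pre-instantiated module passed in (like vae here) is used as-is
# and keeps whatever dtype it already has
pipeline = StableDiffusionXLPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    vae=vae,
    unet=accelerator.unwrap_model(unet),
    torch_dtype=weight_dtype,
)
```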
If no args.pretrained_vae_model_name_or_path is set, the vae is kept in float32:

```python
if args.pretrained_vae_model_name_or_path is None:
    vae.to(accelerator.device, dtype=torch.float32)
```

and the pipeline here does not reload the vae. So vae.dtype (float32) != unet.dtype (fp16), which in my tests causes RuntimeError: Input type (c10::Half) and bias type (float) should be the same.
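For illustration, a minimal standalone repro of this kind of mismatch (hypothetical shapes; the exact error wording varies by device and backend):

```python
import torch

conv = torch.nn.Conv2d(4, 4, kernel_size=3)  # parameters stay in float32
latents = torch.randn(1, 4, 64, 64).half()   # fp16 activations, as from an fp16 unet
conv(latents)  # RuntimeError: input and weight/bias dtypes must match
```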
Ah okay got it!
```python
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)
```
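For context, this upcast is what avoids the unscale failure in the PR title: torch.cuda.amp.GradScaler refuses to unscale gradients held in fp16. A minimal sketch of the failure mode (assumes a CUDA device; names are illustrative):

```python
import torch

# a trainable fp16 parameter, like a LoRA weight before the upcast
param = torch.nn.Parameter(torch.randn(4, device="cuda", dtype=torch.float16))
optimizer = torch.optim.SGD([param], lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

loss = (param ** 2).sum()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # ValueError: Attempting to unscale FP16 gradients.
```

Upcasting only the params with requires_grad keeps the fp16 memory savings for the frozen base weights while letting the optimizer and scaler work in float32.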
Works for me!
TYSM!
Thank you for your contributions!
… (huggingface#6231) Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Referring to commit , this fixes #6086 in `train_text_to_image_lora_sdxl.py`, and also fixes #4619 in `train_text_to_image_lora_sdxl.py`, which surfaced while fixing the former error.