
fix: unscale fp16 gradient problem & potential error (#6086) #6231

Merged · 3 commits merged into huggingface:main on Dec 21, 2023

Conversation

@lvzii (Contributor) commented Dec 19, 2023

Referring to commit ,
this fixes #6086 in train_text_to_image_lora_sdxl.py, and also fixes #4619 in train_text_to_image_lora_sdxl.py, which came up while fixing the former error.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +1201 to +1203
# Make sure vae.dtype is consistent with the unet.dtype
if args.mixed_precision == "fp16":
    vae.to(weight_dtype)
Member
This is not needed in my opinion. We already set the torch_dtype in the pipeline when loading it.

Contributor Author

If no other pretrained_vae_model_name_or_path is set, the vae is cast to float32:

if args.pretrained_vae_model_name_or_path is None:
    vae.to(accelerator.device, dtype=torch.float32)

and the pipeline here does not reload the vae. So vae.dtype (float32) != unet.dtype (fp16), which in my tests causes RuntimeError: Input type (c10::Half) and bias type (float) should be the same.
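
A minimal sketch of this mismatch (the layer, names, and shapes below are invented for illustration; they only mimic an fp32 module receiving fp16 activations, which is what happens when the fp32 vae decodes fp16 latents):

# Hypothetical stand-ins: a float32 "vae-like" conv layer fed fp16 latents.
import torch

vae_like = torch.nn.Conv2d(4, 3, kernel_size=3, padding=1)   # parameters stay in float32
latents = torch.randn(1, 4, 8, 8, dtype=torch.float16)       # fp16, like activations under mixed precision

try:
    vae_like(latents)          # float32 weights vs. fp16 input
except RuntimeError as err:
    print(err)                 # a dtype-mismatch error of the same kind as above

vae_like.to(torch.float16)     # what vae.to(weight_dtype) achieves in the diff: matching dtypes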

Member

Ah okay got it!

Comment on lines +643 to +652
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)
Member

Works for me!
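
As I understand the issue referenced in the PR title, torch.cuda.amp.GradScaler refuses to unscale gradients stored in fp16 ("Attempting to unscale FP16 gradients."), so under fp16 mixed precision the parameters handed to the optimizer must live in float32 while frozen weights can stay in fp16. A toy sketch of the pattern from the snippet commented on above (the two-layer model is invented; only the requires_grad/dtype handling matters):

# Hypothetical model standing in for frozen base weights plus a trainable LoRA adapter.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),   # stands in for frozen base weights
    torch.nn.Linear(8, 8),   # stands in for a trainable LoRA layer
).half()                     # everything starts out in fp16

model[0].requires_grad_(False)   # freeze the "base" layer

# Only upcast trainable parameters into fp32, as the diff does:
for param in model.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

print({name: p.dtype for name, p in model.named_parameters()})
# Frozen params remain torch.float16, trainable ones are now torch.float32,
# so their gradients are fp32 and the scaler's unscale step no longer trips.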

@sayakpaul (Member) left a comment

TYSM!

@sayakpaul sayakpaul merged commit 6ca9c4a into huggingface:main Dec 21, 2023
@sayakpaul (Member)

Thank you for your contributions!

donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 29, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024