
[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) #6514

Merged
11 commits merged on Jan 12, 2024

Conversation

sayakpaul (Member) commented on Jan 10, 2024

What does this PR do?

This PR tries to solve issues like #6442 in a clean way, limiting the changes to the DreamBooth SDXL LoRA script for now.

To test

First run

CUDA_VISIBLE_DEVICES=0 accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --instance_data_dir="dog" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="lora-trained-sdxl" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=4 --checkpointing_steps=2 --checkpoints_total_limit=2 \
  --use_8bit_adam \
  --seed="0"

And then resume:

CUDA_VISIBLE_DEVICES=0 accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --instance_data_dir="dog" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="lora-trained-sdxl" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=8 --checkpointing_steps=2 --checkpoints_total_limit=2 \
  --resume_from_checkpoint="latest" \
  --use_8bit_adam \
  --seed="0"

Tested with the --train_text_encoder flag enabled as well.

I would appreciate it if one of the reviewers could cross-check this.

Comment on lines +1103 to +1112
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)
sayakpaul (Member, Author):
We do this just before handing the parameters to the optimizer to avoid any unintended consequences.
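To make the ordering explicit: the adapters are injected first, the trainable parameters are upcast, and only after that are they collected and passed to the optimizer, so the optimizer state is created against fp32 tensors. A minimal sketch of that last step (build_optimizer is a hypothetical helper, and plain AdamW stands in for the 8-bit optimizer used in the commands above):

import itertools

import torch


def build_optimizer(models, learning_rate):
    """Collect the (already upcast) trainable LoRA params, then create the optimizer."""
    params_to_optimize = list(
        itertools.chain.from_iterable(
            (p for p in m.parameters() if p.requires_grad) for m in models
        )
    )
    # Because the fp32 upcast above ran first, the optimizer only ever sees fp32 tensors.
    return torch.optim.AdamW(params_to_optimize, lr=learning_rate)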

sayakpaul (Member, Author):

In a follow-up PR, I can wrap this utility into a function and move it to training_utils.py, since it's shared by a number of scripts.
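As a rough sketch of what such a shared utility could look like (the name cast_training_params and its signature are assumptions here, not the final API):

import torch


def cast_training_params(models, dtype=torch.float32):
    """Upcast only the trainable (e.g. LoRA) parameters of each model to `dtype`."""
    if not isinstance(models, (list, tuple)):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)

The inline snippet above could then become a single call like cast_training_params([unet, text_encoder_one, text_encoder_two], torch.float32) in each script.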


sayakpaul changed the title from "[Training] fix training resuming problem when using FP16" to "[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth)" on Jan 10, 2024
Comment on lines -1061 to -1072
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
    text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
sayakpaul (Member, Author):

We cannot use load_lora_into_unet() and load_lora_into_text_encoder() here, for the following reason (described only for the UNet, but it applies to the text encoders, too):

  • We call add_adapter() once on unet at the beginning of training. This creates an adapter config inside the UNet.
  • Then, when loading an intermediate checkpoint in the accelerate hook, load_lora_into_unet() would internally call inject_adapter_in_model() again, with the config inferred from the provided state dict. So it would create a second adapter. This is undesirable, right? A sketch of the alternative follows.
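For illustration, one way around the double injection is to write the checkpoint's LoRA weights directly into the adapter modules that add_adapter() already created, for example with peft's set_peft_model_state_dict. This is a sketch only: the "unet." prefix handling is a simplification, and a diffusers-to-peft key conversion may also be needed in practice.

from peft import set_peft_model_state_dict


def load_unet_lora_from_checkpoint(unet, lora_state_dict):
    """Load UNet LoRA weights into the adapter injected by add_adapter(),
    without creating a second adapter the way load_lora_into_unet() would."""
    unet_keys = {
        k[len("unet."):]: v for k, v in lora_state_dict.items() if k.startswith("unet.")
    }
    # Writes tensors into the existing adapter modules; no new adapter config is created.
    incompatible_keys = set_peft_model_state_dict(unet, unet_keys, adapter_name="default")
    return incompatible_keys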

@@ -996,17 +997,6 @@ def main(args):
        text_encoder_one.add_adapter(text_lora_config)
        text_encoder_two.add_adapter(text_lora_config)

    # Make sure the trainable params are in float32.
BenjaminBossan (Member) left a comment:

Thanks Sayak, LGTM.
younesbelkada (Contributor) left a comment:

Looking great! Thanks for all your work on this! I left one comment, wdyt?

sayakpaul (Member, Author):

@BenjaminBossan pinging for #6514 (comment).

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) (huggingface#6514)

* fix: training resume from fp16.

* add: comment

* remove residue from another branch.

* remove more residues.

* thanks to Younes; no hacks.

* style.

* clean things a bit and modularize _set_state_dict_into_text_encoder

* add comment about the fix detailed.