
Fix: training resume from fp16 for SDXL Consistency Distillation #6840

Merged

Conversation

asrimanth
Contributor

@asrimanth commented Feb 4, 2024

What does this PR do?

Part of #6552 for SDXL Consistency Distillation
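
For context, the "shift mixed precision cast before optimizer" commit listed at the end of this thread points at the usual fp16-resume pattern in the diffusers LoRA training scripts: the frozen base weights stay in fp16, but the trainable LoRA parameters are upcast to fp32 before the optimizer is created, so that optimizer state and resumed checkpoints operate on fp32 parameters. A minimal sketch of that pattern (not the exact diff in this PR; args, unet, and accelerator are the usual script-level placeholders, and cast_training_params is the helper from diffusers.training_utils):

import torch
from diffusers.training_utils import cast_training_params

# keep the frozen base weights in fp16 when --mixed_precision="fp16"
weight_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.float32
unet.to(accelerator.device, dtype=weight_dtype)

if args.mixed_precision == "fp16":
    # upcast only the trainable (LoRA) parameters back to fp32 *before*
    # building the optimizer, otherwise resuming from a checkpoint breaks
    cast_training_params(unet, dtype=torch.float32)

params_to_optimize = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params_to_optimize, lr=args.learning_rate)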


What to test?

First run this:

accelerate launch train_lcm_distill_lora_sdxl.py \
  --pretrained_teacher_model="stabilityai/stable-diffusion-xl-base-1.0"  \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --output_dir="pokemons-lora-lcm-sdxl" \
  --mixed_precision="fp16" \
  --dataset_name="lambdalabs/pokemon-blip-captions" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 --gradient_checkpointing \
  --use_8bit_adam \
  --lora_rank=16 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=4 --checkpointing_steps=2 --checkpoints_total_limit=2 \
  --seed="0"

To resume training, run the following command:

accelerate launch train_lcm_distill_lora_sdxl.py \
  --pretrained_teacher_model="stabilityai/stable-diffusion-xl-base-1.0"  \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --output_dir="pokemons-lora-lcm-sdxl" \
  --mixed_precision="fp16" \
  --dataset_name="lambdalabs/pokemon-blip-captions" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 --gradient_checkpointing \
  --use_8bit_adam \
  --lora_rank=16 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=8 --checkpointing_steps=2 --checkpoints_total_limit=2 \
  --seed="0" \
  --resume_from_checkpoint="latest"
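
For reference, --resume_from_checkpoint="latest" is typically resolved in the diffusers example scripts roughly as sketched below (args and accelerator are the usual script-level objects, and checkpoint-<step> is the directory naming convention produced by --checkpointing_steps; this is illustrative, not the exact code in this script):

import os

if args.resume_from_checkpoint == "latest":
    # pick the most recent checkpoint-<step> directory in the output dir
    dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
    dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
    path = dirs[-1] if dirs else None
else:
    path = os.path.basename(args.resume_from_checkpoint)

if path is None:
    initial_global_step = 0  # nothing to resume from, start fresh
else:
    accelerator.load_state(os.path.join(args.output_dir, path))
    initial_global_step = int(path.split("-")[1])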

Who can review?

@sayakpaul

@asrimanth changed the title from "Fix: training resume from fp16 for lcm distill lora sdxl" to "Fix: training resume from fp16 for SDXL Consistency Distillation" on Feb 4, 2024
@sayakpaul
Member

Thanks for your contribution. Could we please reduce the number of steps to quickly check this?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@asrimanth
Contributor Author

asrimanth commented Feb 4, 2024

Do you mean reducing the max_train_steps or the checkpointing_steps? FYI, the above-mentioned tests work for me on my local machine. Also, how do I fix the code quality check?

@sayakpaul
Member

Do you mean reducing the max_train_steps or the checkpointing_steps? FYI, the above-mentioned tests work for me on my local machine. Also, how do I fix the code quality check?

Yeah. I am sure the above command will work, but it will take longer to validate. The effectiveness can essentially be tested with far fewer steps, as I did here. Also, please run the code quality linter with make style && make quality.

@asrimanth
Contributor Author

I've pushed a commit for the code style, and I've edited the test command to save a checkpoint after a smaller number of steps, as per your example.

@sayakpaul
Member

I don't think I made myself super clear. Sorry.

Your example commands still have max_train_steps set to 3000. Why though? Here, I have it set to 4 and 8 for the two commands respectively, along with checkpointing_steps and checkpoints_total_limit set accordingly. I don't have any unnecessary arguments here, such as report_to.

Can we please keep the example commands super lean?

I hope I made myself more clear this time.

@asrimanth
Contributor Author

Yeah, makes sense! I've made the changes so that the training command is leaner and more efficient.

@sayakpaul
Member

I just tried it and I am facing shape mismatch problems while running the second command.

@asrimanth
Contributor Author

Update: I just got a Runpod machine with more VRAM for testing and was able to test the updated script with lower settings. I also made some changes to the code, and it works on my local machine. Please test it out and let me know.

@sayakpaul
Member

Thank you! But now I am seeing some state dict key mismatches when resuming training. Could you look into that?

@asrimanth
Contributor Author

asrimanth commented Feb 8, 2024

Pushed a new commit to fix the missing keys issue. Please have a look and let me know.
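
For context on this class of missing/unexpected key errors: the checkpointing hook saves the LoRA weights in the diffusers serialization format, so on resume they generally have to be remapped to peft-style keys before being loaded into the wrapped UNet. A rough sketch of that load path, assuming a peft-backed UNet (input_dir and unet_ are placeholder names, and this shows the general pattern from the diffusers training examples rather than the exact code in this PR):

from diffusers import StableDiffusionXLPipeline
from diffusers.utils import convert_unet_state_dict_to_peft
from peft import set_peft_model_state_dict

# load the LoRA weights saved by the checkpointing hook (diffusers format)
lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)

# keep only the UNet entries and strip the "unet." prefix
unet_state_dict = {
    k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")
}

# remap diffusers-style keys to peft-style keys and load them into the adapter
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible is not None and getattr(incompatible, "unexpected_keys", None):
    raise ValueError(f"Unexpected keys when loading the LoRA adapter: {incompatible.unexpected_keys}")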

@@ -305,7 +310,7 @@ def parse_args():
     parser.add_argument(
         "--cache_dir",
         type=str,
-        default=None,
+        default="/workspace/cache",
Member

We don't need to change this.

Contributor Author

@asrimanth Feb 8, 2024

Forgot to remove it after local testing. Reverted it to None and pushed a commit. Sorry for the inconvenience.

@sayakpaul
Member

Thanks so much for iterating. I have checked and it works: https://colab.research.google.com/gist/sayakpaul/fd2e863c9911031ad01fa9cf6863a5da/scratchpad.ipynb.

Will merge the PR once the test suite passes.

@asrimanth
Contributor Author

Thank you for the consistent feedback. Happy to contribute to HuggingFace.

@sayakpaul merged commit a11b0f8 into huggingface:main on Feb 8, 2024
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024:
…gingface#6840)

* Fix: training resume from fp16 for lcm distill lora sdxl

* Fix coding quality - run linter

* Fix 1 - shift mixed precision cast before optimizer

* Fix 2 - State dict errors by removing load_lora_into_unet

* Update train_lcm_distill_lora_sdxl.py - Revert default cache dir to None

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>