Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[Training] Add datasets version of LCM LoRA SDXL #5778

Merged
merged 109 commits into from
Dec 26, 2023
Merged

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Nov 13, 2023

What does this PR do?

Add a datasets compatible variant of https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py. It also adapts a couple of best practices from peft:

  • Instead of initializing two UNets (teacher and student), it leverages the disable_adapters() and enable_adapters() functions. In this case, the teacher is without any adapters and the student is with LoRA. We only update the LoRA params in the student.
  • Makes use of peft utility modules positioning it as a utility library rather than a modelling library.

Running a couple experiments. Will report back the findings. But should be more or less ready to be reviewed.

Basic training command

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="lora-lora-sdxl"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_lcm_distill_lora_sdxl.py \
  --pretrained_teacher_model=${MODEL_NAME}  \
  --pretrained_vae_model_name_or_path=${VAE_PATH} \
  --output_dir=${OUTPUT_DIR} \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=1024 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=1e-4 --loss_type="huber" \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1500 \
  --checkpointing_steps=50 \
  --validation_steps=25 \
  --seed="0" \
  --report_to="wandb" \
  --push_to_hub

TODOs

  • Update the PR with results
  • Docs
  • Tests

Copy link
Member

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good so far, let me know if you need help testing / debugging

@sayakpaul
Copy link
Member Author

It works, needs hyperparameter tuning. If you could take it to a test run, that would be great!

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for working on this. Just left some comments. Most importantly:

Let's make sure we can load the saved lora with the pipeline for inference. I had some issues with it when I tried.

Everything else looks good !

Comment on lines +779 to +793
target_modules=[
"to_q",
"to_k",
"to_v",
"to_out.0",
"proj_in",
"proj_out",
"ff.net.0.proj",
"ff.net.2",
"conv1",
"conv2",
"conv_shortcut",
"downsamplers.0.conv",
"upsamplers.0.conv",
"time_emb_proj",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for later:

We could also think of making this an argument.

@sayakpaul
Copy link
Member Author

Let's make sure we can load the saved lora with the pipeline for inference. I had some issues with it when I tried.

Could you post the snippet with which you tried and the error you got?

@patil-suraj
Copy link
Contributor

Let's make sure we can load the saved lora with the pipeline for inference. I had some issues with it when I tried.

Could you post the snippet with which you tried and the error you got?

When I try to load the lora saved with script using

pipe.load_lora_weights(path_to_saved_lora, weight_name="pytorch_lora_weights.safetensors")

I get

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In [13], line 1
----> 1 pipe.load_lora_weights("valhalla/lcm-sdxl-distill-lora-k5", weight_name="pytorch_lora_weights.safetensors")

File ~/diffusers/src/diffusers/loaders.py:3233, in StableDiffusionXLLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
   3230 if not is_correct_format:
   3231     raise ValueError("Invalid LoRA checkpoint.")
-> 3233 self.load_lora_into_unet(
   3234     state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
   3235 )
   3236 text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
   3237 if len(text_encoder_state_dict) > 0:

File ~/diffusers/src/diffusers/loaders.py:1630, in LoraLoaderMixin.load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage, adapter_name, _pipeline)
   1627     if "lora_B" in key:
   1628         rank[key] = val.shape[1]
-> 1630 lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
   1631 lora_config = LoraConfig(**lora_config_kwargs)
   1633 # adapter_name

File ~/diffusers/src/diffusers/utils/peft_utils.py:122, in get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet)
    120 rank_pattern = {}
    121 alpha_pattern = {}
--> 122 r = lora_alpha = list(rank_dict.values())[0]
    124 if len(set(rank_dict.values())) > 1:
    125     # get the rank occuring the most number of times
    126     r = collections.Counter(rank_dict.values()).most_common()[0][0]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sayakpaul sayakpaul changed the title [WIP][Training] Add datasets version of LCM LoRA SDXL [Training] Add datasets version of LCM LoRA SDXL Dec 20, 2023
@sayakpaul
Copy link
Member Author

@dg845 as well if you want to give this a look :-)

Comment on lines +195 to +198
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
return c_skip, c_out
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could make this a little more general by using the timestep_scaling argument (and perhaps expose this as an argument):

Suggested change
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
return c_skip, c_out
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
scaled_timestep = timestep_scaling * timestep
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
return c_skip, c_out

(The current function is the same as in examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py so if we make this change we should probably also propagate it to the WebDataset scripts.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would prefer reflecting that after it's changed in WDS so that it's easier to track.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, let's adapt this for both scripts.

Comment on lines +792 to +793
r=args.lora_rank,
lora_alpha=args.lora_rank,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to allow lora_alpha to be set independently of the r/lora_rank, to allow the LoRA layer scaling to be controlled?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe in a future PR. Since this PR is almost just a copy-paste of the WDS version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense.

@sayakpaul
Copy link
Member Author

@patil-suraj @dg845 ran a recent experiment with the following:

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="lora-lcm-sdxl-new"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

CUDA_VISIBLE_DEVICES=1 accelerate launch train_lcm_distill_lora_sdxl.py \
  --pretrained_teacher_model=${MODEL_NAME}  \
  --pretrained_vae_model_name_or_path=${VAE_PATH} \
  --output_dir="lora-lcm-sdxl-new" \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=1024 \
  --train_batch_size=24 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --lora_rank=64 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=10000 \
  --checkpointing_steps=3000 \
  --validation_steps=50 \
  --seed="0" \
  --report_to="wandb" \
  --push_to_hub

WandB: https://wandb.ai/sayakpaul/text2image-fine-tune/runs/tv3zw00t
Weights: https://huggingface.co/sayakpaul/pokemons-lora-lcm-sdxl

Given it's only 10k steps and I haven't ablated the hyperparameters, I'd say it's still good enough. WDYT? I'd be keen on merging this soon as it reduces the memory requirements significantly by following good PEFT practices.

WDYT?

@patil-suraj
Copy link
Contributor

Sounds good! Usually in my experiment 2-3k steps were enough, 10k results in overfitting with large BS.

Comment on lines +195 to +198
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
return c_skip, c_out
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, let's adapt this for both scripts.

Comment on lines +792 to +793
r=args.lora_rank,
lora_alpha=args.lora_rank,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense.

kashif and others added 9 commits December 26, 2023 20:22
* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
fix fp16 training

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@sayakpaul
Copy link
Member Author

sayakpaul commented Dec 26, 2023

@patil-suraj I added the following for the reference training command when training on a very small dataset like Pokemons:

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_lcm_distill_lora_sdxl.py \
  --pretrained_teacher_model=${MODEL_NAME}  \
  --pretrained_vae_model_name_or_path=${VAE_PATH} \
  --output_dir="pokemons-lora-lcm-sdxl" \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=1024 \
  --train_batch_size=24 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --lora_rank=64 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=3000 \
  --checkpointing_steps=500 \
  --validation_steps=50 \
  --seed="0" \
  --report_to="wandb" \
  --push_to_hub

This has 3k steps instead of 10k. The training dynamics for this aren't clear but https://wandb.ai/sayakpaul/text2image-fine-tune/runs/tv3zw00t definitely shows progress IMO.

@sayakpaul
Copy link
Member Author

Will merge after the CI is green.

@sayakpaul sayakpaul merged commit 6683f97 into main Dec 26, 2023
@sayakpaul sayakpaul deleted the lcm-lora-sdxl-datasets branch December 26, 2023 15:52
donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 29, 2023
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (huggingface#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix a bug in `add_noise` function  (huggingface#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [Advanced Diffusion Script] Add Widget default text (huggingface#6100)

add widget

* [Advanced Training Script] Fix pipe example (huggingface#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (huggingface#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* IP adapter support for most pipelines (huggingface#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix: lora_alpha

* make vae casting conditional/

* param upcasting

* propagate comments from huggingface#6145

Co-authored-by: dg845 <dgu8957@gmail.com>

* [Peft] fix saving / loading when unet is not "unet" (huggingface#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)

fix fp16 training

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [docs] fix: animatediff docs (huggingface#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (huggingface#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (huggingface#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Fabio Rigano <57982783+fabiorigano@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
Co-authored-by: Charchit Sharma <charchitsharma11@gmail.com>
Co-authored-by: Aryan V S <contact.aryanvs@gmail.com>
Co-authored-by: dg845 <dgu8957@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
antoine-scenario pushed a commit to antoine-scenario/diffusers that referenced this pull request Jan 2, 2024
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (huggingface#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix a bug in `add_noise` function  (huggingface#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [Advanced Diffusion Script] Add Widget default text (huggingface#6100)

add widget

* [Advanced Training Script] Fix pipe example (huggingface#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (huggingface#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* IP adapter support for most pipelines (huggingface#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix: lora_alpha

* make vae casting conditional/

* param upcasting

* propagate comments from huggingface#6145

Co-authored-by: dg845 <dgu8957@gmail.com>

* [Peft] fix saving / loading when unet is not "unet" (huggingface#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)

fix fp16 training

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [docs] fix: animatediff docs (huggingface#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (huggingface#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (huggingface#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Fabio Rigano <57982783+fabiorigano@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
Co-authored-by: Charchit Sharma <charchitsharma11@gmail.com>
Co-authored-by: Aryan V S <contact.aryanvs@gmail.com>
Co-authored-by: dg845 <dgu8957@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* add: script to train lcm lora for sdxl with 🤗 datasets

* suit up the args.

* remove comments.

* fix num_update_steps

* fix batch unmarshalling

* fix num_update_steps_per_epoch

* fix; dataloading.

* fix microconditions.

* unconditional predictions debug

* fix batch size.

* no need to use use_auth_token

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* make vae encoding batch size an arg

* final serialization in kohya

* style

* state dict rejigging

* feat: no separate teacher unet.

* debug

* fix state dict serialization

* debug

* debug

* debug

* remove prints.

* remove kohya utility and make style

* fix serialization

* fix

* add test

* add peft dependency.

* add: peft

* remove peft

* autocast device determination from accelerator

* autocast

* reduce lora rank.

* remove unneeded space

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* style

* remove prompt dropout.

* also save in native diffusers ckpt format.

* debug

* debug

* debug

* better formation of the null embeddings.

* remove space.

* autocast fixes.

* autocast fix.

* hacky

* remove lora_sayak

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* style

* make log validation leaner.

* move back enabled in.

* fix: log_validation call.

* add: checkpointing tests

* taking my chances to see if disabling autocasting has any effect?

* start debugging

* name

* name

* name

* more debug

* more debug

* index

* remove index.

* print length

* print length

* print length

* move unet.train() after add_adapter()

* disable some prints.

* enable_adapters() manually.

* remove prints.

* some changes.

* fix params_to_optimize

* more fixes

* debug

* debug

* remove print

* disable grad for certain contexts.

* Add support for IPAdapterFull (huggingface#5911)

* Add support for IPAdapterFull


Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix a bug in `add_noise` function  (huggingface#6085)

* fix

* copies

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>

* [Advanced Diffusion Script] Add Widget default text (huggingface#6100)

add widget

* [Advanced Training Script] Fix pipe example (huggingface#6106)

* IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (huggingface#5901)

* adapter for StableDiffusionControlNetImg2ImgPipeline

* fix-copies

* fix-copies

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* IP adapter support for most pipelines (huggingface#5900)

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix: lora_alpha

* make vae casting conditional/

* param upcasting

* propagate comments from huggingface#6145

Co-authored-by: dg845 <dgu8957@gmail.com>

* [Peft] fix saving / loading when unet is not "unet" (huggingface#6046)

* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)

fix fp16 training

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [docs] fix: animatediff docs (huggingface#6339)

fix: animatediff docs

* add: note about the new script in readme_sdxl.

* Revert "[Peft] fix saving / loading when unet is not "unet" (huggingface#6046)"

This reverts commit 4c7e983.

* Revert "[Wuerstchen] fix fp16 training and correct lora args (huggingface#6245)"

This reverts commit 0bb9cf0.

* Revert "[docs] fix: animatediff docs (huggingface#6339)"

This reverts commit 11659a6.

* remove tokenize_prompt().

* assistive comments around enable_adapters() and diable_adapters().

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Fabio Rigano <57982783+fabiorigano@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
Co-authored-by: Charchit Sharma <charchitsharma11@gmail.com>
Co-authored-by: Aryan V S <contact.aryanvs@gmail.com>
Co-authored-by: dg845 <dgu8957@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.