[Training] Add datasets version of LCM LoRA SDXL #5778

Merged · 109 commits · Dec 26, 2023

Commits
ca7f220
add: script to train lcm lora for sdxl with 🤗 datasets
sayakpaul Nov 13, 2023
88efd15
suit up the args.
sayakpaul Nov 13, 2023
9e49fd2
remove comments.
sayakpaul Nov 13, 2023
728aa8a
fix num_update_steps
sayakpaul Nov 13, 2023
bc8cfdd
fix batch unmarshalling
sayakpaul Nov 13, 2023
8c4d4b6
fix num_update_steps_per_epoch
sayakpaul Nov 13, 2023
6d2f740
fix: dataloading.
sayakpaul Nov 13, 2023
c7f2828
fix microconditions.
sayakpaul Nov 13, 2023
df70754
unconditional predictions debug
sayakpaul Nov 13, 2023
dd93227
fix batch size.
sayakpaul Nov 13, 2023
3d4b1da
no need to use use_auth_token
sayakpaul Nov 13, 2023
7967247
Apply suggestions from code review
sayakpaul Nov 13, 2023
6b2e42f
make vae encoding batch size an arg
sayakpaul Nov 13, 2023
d7f632e
final serialization in kohya
sayakpaul Nov 13, 2023
e4edb31
style
sayakpaul Nov 13, 2023
858009b
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Nov 14, 2023
6aa2dd8
state dict rejigging
sayakpaul Nov 14, 2023
1fd3378
feat: no separate teacher unet.
sayakpaul Nov 14, 2023
4135414
debug
sayakpaul Nov 14, 2023
3b066d2
fix state dict serialization
sayakpaul Nov 14, 2023
fc5546f
debug
sayakpaul Nov 14, 2023
ba0d0f2
debug
sayakpaul Nov 14, 2023
35e30fb
debug
sayakpaul Nov 14, 2023
53c13f7
remove prints.
sayakpaul Nov 14, 2023
cff23ed
remove kohya utility and make style
sayakpaul Nov 14, 2023
ca076c7
fix serialization
sayakpaul Nov 14, 2023
808f61e
fix
sayakpaul Nov 14, 2023
842df25
add test
sayakpaul Nov 14, 2023
0027673
add peft dependency.
sayakpaul Nov 14, 2023
c625553
add: peft
sayakpaul Nov 14, 2023
c5317ff
remove peft
sayakpaul Nov 14, 2023
6a690ab
autocast device determination from accelerator
sayakpaul Nov 14, 2023
8c4eaf6
autocast
sayakpaul Nov 14, 2023
cece781
reduce lora rank.
sayakpaul Nov 14, 2023
beb8aa2
remove unneeded space
sayakpaul Nov 14, 2023
33cb9d0
Apply suggestions from code review
sayakpaul Nov 14, 2023
795cc9f
style
sayakpaul Nov 14, 2023
042f357
remove prompt dropout.
sayakpaul Nov 14, 2023
283af65
also save in native diffusers ckpt format.
sayakpaul Nov 14, 2023
5e099a2
debug
sayakpaul Nov 14, 2023
71db43a
debug
sayakpaul Nov 14, 2023
e1346d5
debug
sayakpaul Nov 14, 2023
dfcf234
better formation of the null embeddings.
sayakpaul Nov 14, 2023
5ce6cc1
remove space.
sayakpaul Nov 14, 2023
7ee9d5d
autocast fixes.
sayakpaul Nov 14, 2023
1b359ae
autocast fix.
sayakpaul Nov 14, 2023
82b628a
hacky
sayakpaul Nov 14, 2023
17d5c0d
remove lora_sayak
sayakpaul Nov 16, 2023
fea95e0
Apply suggestions from code review
sayakpaul Nov 16, 2023
83801a6
style
sayakpaul Nov 16, 2023
0c5d934
make log validation leaner.
sayakpaul Nov 16, 2023
3b034be
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Nov 16, 2023
0f42185
move back enabled in.
sayakpaul Nov 16, 2023
41f1925
fix: log_validation call.
sayakpaul Nov 16, 2023
bf5c5d6
add: checkpointing tests
sayakpaul Nov 16, 2023
64063c7
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Nov 17, 2023
53cf0e7
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Nov 17, 2023
de958dc
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Nov 17, 2023
5824fa3
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Nov 27, 2023
f52cb6e
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Nov 27, 2023
5534b0c
taking my chances to see if disabling autocasting has any effect?
sayakpaul Nov 27, 2023
3bacd82
resolve conflicts
sayakpaul Nov 30, 2023
1da3071
start debugging
sayakpaul Nov 30, 2023
bd4d1c4
name
sayakpaul Nov 30, 2023
26f16c1
name
sayakpaul Nov 30, 2023
9174027
name
sayakpaul Nov 30, 2023
92ba868
more debug
sayakpaul Nov 30, 2023
1fba251
more debug
sayakpaul Nov 30, 2023
3751ca9
index
sayakpaul Nov 30, 2023
63649d3
remove index.
sayakpaul Nov 30, 2023
05de542
print length
sayakpaul Nov 30, 2023
5e604a8
print length
sayakpaul Nov 30, 2023
8fecdda
print length
sayakpaul Nov 30, 2023
023866f
move unet.train() after add_adapter()
sayakpaul Dec 1, 2023
07c28de
disable some prints.
sayakpaul Dec 1, 2023
c6a61da
enable_adapters() manually.
sayakpaul Dec 1, 2023
ec33085
remove prints.
sayakpaul Dec 2, 2023
d14dd41
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Dec 3, 2023
ed7969d
some changes.
sayakpaul Dec 3, 2023
8c549e4
fix params_to_optimize
sayakpaul Dec 3, 2023
9446066
more fixes
sayakpaul Dec 3, 2023
0153665
debug
sayakpaul Dec 3, 2023
b9891ff
debug
sayakpaul Dec 3, 2023
b11b0a6
remove print
sayakpaul Dec 3, 2023
539bda3
disable grad for certain contexts.
sayakpaul Dec 3, 2023
dfe916d
Merge branch 'main' into lcm-lora-sdxl-datasets
patil-suraj Dec 7, 2023
d5a40cd
Add support for IPAdapterFull (#5911)
fabiorigano Dec 7, 2023
e3d76c4
Fix a bug in `add_noise` function (#6085)
yiyixuxu Dec 7, 2023
472c397
[Advanced Diffusion Script] Add Widget default text (#6100)
apolinario Dec 8, 2023
373d392
[Advanced Training Script] Fix pipe example (#6106)
apolinario Dec 8, 2023
be46b6e
IP-Adapter for StableDiffusionControlNetImg2ImgPipeline (#5901)
charchit7 Dec 9, 2023
c7a87ca
IP adapter support for most pipelines (#5900)
a-r-r-o-w Dec 10, 2023
556b797
resolve conflicts
sayakpaul Dec 15, 2023
a8d9785
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Dec 20, 2023
47abcf6
fix: lora_alpha
sayakpaul Dec 20, 2023
b7c0f95
make vae casting conditional.
sayakpaul Dec 20, 2023
7a1d6c9
param upcasting
sayakpaul Dec 20, 2023
87f87a7
propagate comments from https://github.com/huggingface/diffusers/pull…
sayakpaul Dec 20, 2023
404351f
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Dec 26, 2023
4c7e983
[Peft] fix saving / loading when unet is not "unet" (#6046)
kashif Dec 26, 2023
0bb9cf0
[Wuerstchen] fix fp16 training and correct lora args (#6245)
kashif Dec 26, 2023
11659a6
[docs] fix: animatediff docs (#6339)
sayakpaul Dec 26, 2023
f645b87
add: note about the new script in readme_sdxl.
sayakpaul Dec 26, 2023
fd64acf
Revert "[Peft] fix saving / loading when unet is not "unet" (#6046)"
sayakpaul Dec 26, 2023
121567b
Revert "[Wuerstchen] fix fp16 training and correct lora args (#6245)"
sayakpaul Dec 26, 2023
c24626a
Revert "[docs] fix: animatediff docs (#6339)"
sayakpaul Dec 26, 2023
4c689b2
remove tokenize_prompt().
sayakpaul Dec 26, 2023
1b49fb9
assistive comments around enable_adapters() and disable_adapters().
sayakpaul Dec 26, 2023
9b3dbaa
Merge branch 'main' into lcm-lora-sdxl-datasets
sayakpaul Dec 26, 2023

Files changed

@@ -161,6 +161,8 @@ def save_model_card(
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""

36 changes: 35 additions & 1 deletion examples/consistency_distillation/README_sdxl.md
@@ -111,4 +111,38 @@ accelerate launch train_lcm_distill_lora_sdxl_wds.py \
--report_to=wandb \
--seed=453645634 \
--push_to_hub \
```

We provide another version of the LCM LoRA SDXL training script that follows the best practices of `peft` and leverages the `datasets` library for quick experimentation. Unlike `train_lcm_distill_lora_sdxl_wds.py`, this script doesn't load two UNets, which reduces the memory requirements quite a bit.
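
The memory saving comes from `peft` adapter toggling: with the LoRA adapters disabled, the same UNet recovers the frozen teacher, so no second copy is kept around. Below is a minimal sketch of the idea, not the script itself; the `target_modules` list is an illustrative subset:

```python
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Load a single UNet; its frozen base weights double as the teacher.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)

# Attach LoRA adapters; only these low-rank matrices receive gradients.
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # illustrative subset
)
unet.add_adapter(lora_config)

# Student pass: adapters are active (the default after add_adapter()).
# student_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

# Teacher pass: toggle the adapters off instead of loading a second UNet.
with torch.no_grad():
    unet.disable_adapters()
    # teacher_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    unet.enable_adapters()
```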

Below is an example training command that trains an LCM LoRA on the [Pokémon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions):

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"

accelerate launch train_lcm_distill_lora_sdxl.py \
--pretrained_teacher_model=${MODEL_NAME} \
--pretrained_vae_model_name_or_path=${VAE_PATH} \
--output_dir="pokemons-lora-lcm-sdxl" \
--mixed_precision="fp16" \
--dataset_name=$DATASET_NAME \
--resolution=1024 \
--train_batch_size=24 \
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--use_8bit_adam \
--lora_rank=64 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=3000 \
--checkpointing_steps=500 \
--validation_steps=50 \
--seed="0" \
--push_to_hub
```
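
Once trained, the LoRA can be plugged into the SDXL pipeline together with `LCMScheduler` for few-step inference. A minimal sketch, assuming the `--output_dir` from the command above (prompt and file name are illustrative):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Swap in the LCM scheduler and load the freshly trained LoRA.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("pokemons-lora-lcm-sdxl")

# LCM LoRAs are typically run with very few steps and guidance disabled.
image = pipe(
    "a green pokemon with big eyes", num_inference_steps=4, guidance_scale=1.0
).images[0]
image.save("pokemon.png")
```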

112 changes: 112 additions & 0 deletions examples/consistency_distillation/test_lcm_lora.py
@@ -0,0 +1,112 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class TextToImageLCM(ExamplesTestsAccelerate):
    def test_text_to_image_lcm_lora_sdxl(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

    def test_text_to_image_lcm_lora_sdxl_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --checkpointing_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            # 7 train steps with checkpointing every 2 steps -> checkpoints at steps 2, 4, and 6
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6"},
            )

            test_args = f"""
                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --lora_rank 4
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 9
                --checkpointing_steps 2
                --resume_from_checkpoint latest
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)

            # resuming from the latest checkpoint (checkpoint-6) and training to step 9 adds checkpoint-8
            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
            )