
Commit 5c1776e

patil-suraj authored and kashif committed
Add LCM Scripts (huggingface#5727)
* add lcm scripts
* Co-authored-by: dgu8957@gmail.com
1 parent 33a8556 commit 5c1776e

7 files changed: +5616 -0 lines changed
+104
@@ -0,0 +1,104 @@
# Latent Consistency Distillation Example

[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method for distilling a latent diffusion model so that it can generate high-quality images in very few inference steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with fewer timesteps.

## Full model distillation

### Running locally with PyTorch

#### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
#### Example with LAION-A6+ dataset

```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
    --pretrained_teacher_model=$MODEL_DIR \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 \
    --push_to_hub
```
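
Once training finishes, the distilled student UNet can be plugged back into the regular Stable Diffusion pipeline together with `LCMScheduler` for few-step sampling. The snippet below is only a minimal sketch: the `path/to/saved/model` path and the prompt are placeholders, and it assumes the final UNet ends up in the standard diffusers format under a `unet` subfolder of `$OUTPUT_DIR`.

```python
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler

# Load the distilled student UNet (placeholder path for $OUTPUT_DIR).
unet = UNet2DConditionModel.from_pretrained(
    "path/to/saved/model", subfolder="unet", torch_dtype=torch.float16
)

# Plug it into the original SD v1.5 pipeline and switch to the LCM scheduler.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# LCMs only need a handful of steps; guidance_scale=1.0 disables
# classifier-free guidance, since guidance is distilled into the student.
image = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_sd_sample.png")
```
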
## LCM-LoRA

Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any fine-tuned Stable Diffusion v1.5 model.

### Example with LAION-A6+ dataset
```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_lora_sd_wds.py \
    --pretrained_teacher_model=$MODEL_DIR \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --lora_rank=64 \
    --learning_rate=1e-6 --loss_type="huber" --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 \
    --push_to_hub
```
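
The resulting LoRA can then be loaded on top of the teacher model at inference time. This is only a rough sketch, assuming the adapter saved in `$OUTPUT_DIR` is in a format that `load_lora_weights` understands; the path and prompt are placeholders.

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Inject the distilled LCM-LoRA weights (placeholder path for $OUTPUT_DIR).
pipe.load_lora_weights("path/to/saved/model")
pipe.to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_lora_sd_sample.png")
```
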
@@ -0,0 +1,106 @@
# Latent Consistency Distillation Example

[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method for distilling a latent diffusion model so that it can generate high-quality images in very few inference steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with fewer timesteps.

## Full model distillation

### Running locally with PyTorch

#### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the example folder and run:

```bash
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
#### Example with LAION-A6+ dataset

```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sdxl_wds.py \
    --pretrained_teacher_model=$MODEL_DIR \
    --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=1024 \
    --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --ema_decay=0.95 --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 \
    --push_to_hub
```
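
Once training finishes, the distilled SDXL UNet can be used with the regular SDXL pipeline and `LCMScheduler` for few-step sampling. The snippet below is only a minimal sketch: the `path/to/saved/model` path and the prompt are placeholders, and it assumes the final UNet ends up in the standard diffusers format under a `unet` subfolder of `$OUTPUT_DIR`.

```python
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler

# Load the distilled student UNet (placeholder path for $OUTPUT_DIR).
unet = UNet2DConditionModel.from_pretrained(
    "path/to/saved/model", subfolder="unet", torch_dtype=torch.float16
)

# Plug it into the SDXL base pipeline and switch to the LCM scheduler.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# Few-step sampling; guidance is distilled into the student, so CFG is disabled.
image = pipe(
    "close-up photography of an old man standing in the rain at night",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_sdxl_sample.png")
```
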
## LCM-LoRA

Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.

### Example with LAION-A6+ dataset
```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_lora_sdxl_wds.py \
    --pretrained_teacher_model=$MODEL_DIR \
    --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=1024 \
    --lora_rank=64 \
    --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 \
    --push_to_hub
```
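
The resulting LoRA can then be loaded on top of the SDXL base model at inference time. This is only a rough sketch, assuming the adapter saved in `$OUTPUT_DIR` is in a format that `load_lora_weights` understands; the path and prompt are placeholders.

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Inject the distilled LCM-LoRA adapter (placeholder path for $OUTPUT_DIR).
pipe.load_lora_weights("path/to/saved/model")
pipe.to("cuda")

image = pipe(
    "close-up photography of an old man standing in the rain at night",
    num_inference_steps=4,
    guidance_scale=1.0,
).images[0]
image.save("lcm_lora_sdxl_sample.png")
```
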
@@ -0,0 +1,7 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
webdataset
