
Commit 43ff401

younesbelkada authored and sayakpaul committed
[PEFT] Adapt example scripts to use PEFT (huggingface#5388)
* adapt example scripts to use PEFT
* Update examples/text_to_image/train_text_to_image_lora.py
* fix
* add for SDXL
* oops
* make sure to install peft
* fix
* fix
* fix dreambooth and lora
* more fixes
* add peft to requirements.txt
* fix
* final fix
* add peft version in requirements
* remove comment
* change variable names
* add few lines in readme
* add to reqs
* style
* fix issues
* fix lora dreambooth xl tests
* init_lora_weights to gaussian and add out proj where missing
* ammend requirements.
* ammend requirements.txt
* add correct peft versions

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 193ffe5 commit 43ff401

13 files changed (+121, -286 lines)

.github/workflows/pr_tests.yml (+1)
@@ -113,6 +113,7 @@ jobs:
       - name: Run example PyTorch CPU tests
         if: ${{ matrix.config.framework == 'pytorch_examples' }}
         run: |
+          python -m pip install peft
           python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
             --make-reports=tests_${{ matrix.config.report }} \
             examples

examples/dreambooth/README.md (+1)
@@ -44,6 +44,7 @@ write_basic_config()
```

When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.

### Dog toy example

examples/dreambooth/README_sdxl.md (+1)
@@ -47,6 +47,7 @@ write_basic_config()
```

When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.

### Dog toy example

examples/dreambooth/requirements.txt (+1)
@@ -4,3 +4,4 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
+peft==0.7.0

examples/dreambooth/requirements_sdxl.txt (+1)
@@ -4,3 +4,4 @@ transformers>=4.25.1
ftfy
tensorboard
Jinja2
+peft==0.7.0

examples/dreambooth/train_dreambooth_lora.py (+27, -99)
@@ -16,7 +16,6 @@
import argparse
import copy
import gc
-import itertools
import logging
import math
import os
@@ -35,6 +34,8 @@
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
@@ -52,14 +53,7 @@
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.attention_processor import (
-    AttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
-    SlicedAttnAddedKVProcessor,
-)
-from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

@@ -864,79 +858,19 @@ def main(args):
            text_encoder.gradient_checkpointing_enable()

    # now we will add new LoRA weights to the attention layers
-    # It's important to realize here how many attention weights will be added and of which sizes
-    # The sizes of the attention layers consist only of two different variables:
-    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
-    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
-
-    # Let's first see how many attention processors we will have to set.
-    # For Stable Diffusion, it should be equal to:
-    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
-    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
-    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
-    # => 32 layers
-
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
-
-        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
-            attn_module.add_k_proj.set_lora_layer(
-                LoRALinearLayer(
-                    in_features=attn_module.add_k_proj.in_features,
-                    out_features=attn_module.add_k_proj.out_features,
-                    rank=args.rank,
-                )
-            )
-            attn_module.add_v_proj.set_lora_layer(
-                LoRALinearLayer(
-                    in_features=attn_module.add_v_proj.in_features,
-                    out_features=attn_module.add_v_proj.out_features,
-                    rank=args.rank,
-                )
-            )
-            unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
-            unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+    )
+    unet.add_adapter(unet_lora_config)

-    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks.
+    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
    if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
+        text_lora_config = LoraConfig(
+            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+        )
+        text_encoder.add_adapter(text_lora_config)

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
@@ -948,9 +882,9 @@ def save_model_hook(models, weights, output_dir):

        for model in models:
            if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_lora_state_dict(model)
+                unet_lora_layers_to_save = get_peft_model_state_dict(model)
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1010,11 +944,10 @@ def load_model_hook(models, input_dir):
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
-    params_to_optimize = (
-        itertools.chain(unet_lora_parameters, text_lora_parameters)
-        if args.train_text_encoder
-        else unet_lora_parameters
-    )
+    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+    if args.train_text_encoder:
+        params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
@@ -1257,12 +1190,7 @@ def compute_text_embeddings(prompt):

                accelerator.backward(loss)
                if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(unet_lora_parameters, text_lora_parameters)
-                        if args.train_text_encoder
-                        else unet_lora_parameters
-                    )
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
@@ -1385,19 +1313,19 @@ def compute_text_embeddings(prompt):
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)

-        if text_encoder is not None and args.train_text_encoder:
+        unet_lora_state_dict = get_peft_model_state_dict(unet)
+
+        if args.train_text_encoder:
            text_encoder = accelerator.unwrap_model(text_encoder)
-            text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+            text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
        else:
-            text_encoder_lora_layers = None
+            text_encoder_state_dict = None

        LoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
-            unet_lora_layers=unet_lora_layers,
-            text_encoder_lora_layers=text_encoder_lora_layers,
+            unet_lora_layers=unet_lora_state_dict,
+            text_encoder_lora_layers=text_encoder_state_dict,
        )

    # Final inference
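The core of this file's change is the switch from hand-built `LoRALinearLayer` injection to a declarative `LoraConfig` plus `add_adapter`, with `get_peft_model_state_dict` producing the LoRA-only state dict at save time. Below is a minimal sketch of that pattern, assuming diffusers with the PEFT backend and `peft>=0.6.0` are installed; the tiny UNet configuration is purely illustrative (the script itself loads the pretrained Stable Diffusion UNet and passes `args.rank`).

```python
# Minimal sketch of the PEFT-backed LoRA pattern this commit adopts.
# Assumptions: diffusers with PEFT integration and peft>=0.6.0 installed;
# the tiny UNet config below is for illustration only, not the real checkpoint.
from diffusers import UNet2DConditionModel
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# Declarative LoRA config replaces the manual per-attention-module LoRALinearLayer loop.
unet_lora_config = LoraConfig(
    r=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

# Only the injected LoRA weights require grad, so collecting optimizer params is a filter,
# mirroring the new `filter(lambda p: p.requires_grad, ...)` lines in the script.
lora_params = [p for p in unet.parameters() if p.requires_grad]
print(f"trainable LoRA parameters: {sum(p.numel() for p in lora_params):,}")

# LoRA-only state dict, in the format the updated save_model_hook serializes.
print(len(get_peft_model_state_dict(unet)), "LoRA tensors")
```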

examples/dreambooth/train_dreambooth_lora_sdxl.py (+23, -67)
@@ -34,6 +34,8 @@
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
@@ -50,9 +52,8 @@
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
+from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

@@ -1009,54 +1010,19 @@ def main(args):
            text_encoder_two.gradient_checkpointing_enable()

    # now we will add new LoRA weights to the attention layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]
+    )
+    unet.add_adapter(unet_lora_config)

    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
+        text_lora_config = LoraConfig(
+            r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
        )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
@@ -1069,11 +1035,11 @@ def save_model_hook(models, weights, output_dir):

        for model in models:
            if isinstance(model, type(accelerator.unwrap_model(unet))):
-                unet_lora_layers_to_save = unet_lora_state_dict(model)
+                unet_lora_layers_to_save = get_peft_model_state_dict(model)
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1130,6 +1096,12 @@ def load_model_hook(models, input_dir):
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

+    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
    # Optimization parameters
    unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
    if args.train_text_encoder:
@@ -1194,26 +1166,10 @@ def load_model_hook(models, input_dir):

        optimizer_class = prodigyopt.Prodigy

-        if args.learning_rate <= 0.1:
-            logger.warn(
-                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
-            )
-        if args.train_text_encoder and args.text_encoder_lr:
-            logger.warn(
-                f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
-                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
-                f"When using prodigy only learning_rate is used as the initial learning rate."
-            )
-            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
-            # --learning_rate
-            params_to_optimize[1]["lr"] = args.learning_rate
-            params_to_optimize[2]["lr"] = args.learning_rate
-
        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
-            beta3=args.prodigy_beta3,
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
            decouple=args.prodigy_decouple,
@@ -1659,13 +1615,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
+        unet_lora_layers = get_peft_model_state_dict(unet)

        if args.train_text_encoder:
            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
        else:
            text_encoder_lora_layers = None
            text_encoder_2_lora_layers = None
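In the SDXL script the same idea is applied to the two text encoders: a single `LoraConfig` is attached to both via `add_adapter`, and the trainable parameters are collected with a `requires_grad` filter. Below is a minimal sketch under the same assumptions (transformers with PEFT integration and `peft>=0.6.0`); the tiny `CLIPTextConfig` is a stand-in for the real SDXL text encoders, which the script loads from the pretrained checkpoint.

```python
# Companion sketch for the text-encoder side of the change.
# Assumptions: transformers with PEFT integration and peft>=0.6.0 installed;
# the tiny CLIPTextConfig below is illustrative, not the SDXL encoders.
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from transformers import CLIPTextConfig, CLIPTextModel

tiny_config = CLIPTextConfig(
    vocab_size=1000,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
)
text_encoder_one = CLIPTextModel(tiny_config)
text_encoder_two = CLIPTextModel(tiny_config)

# One LoraConfig can be attached to both encoders; each gets its own adapter weights.
text_lora_config = LoraConfig(
    r=4, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# Mirrors the new `filter(lambda p: p.requires_grad, ...)` pattern used to build
# the optimizer parameter groups in the script.
text_lora_parameters_one = [p for p in text_encoder_one.parameters() if p.requires_grad]
text_lora_parameters_two = [p for p in text_encoder_two.parameters() if p.requires_grad]

# LoRA-only state dicts, as serialized by the updated save_model_hook.
print(len(get_peft_model_state_dict(text_encoder_one)), len(get_peft_model_state_dict(text_encoder_two)))
```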

examples/text_to_image/README.md (+2)
@@ -32,6 +32,8 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
accelerate config
```

+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
+
### Pokemon example

You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
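The READMEs now require `peft>=0.6.0` while the requirements files pin `peft==0.7.0`. A small, hypothetical pre-flight check (not part of the commit) that verifies the installed version before launching training:

```python
# Hypothetical sanity check, not part of the commit: confirm the installed PEFT
# version satisfies the requirement the READMEs now document.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

try:
    installed = Version(version("peft"))
except PackageNotFoundError as exc:
    raise RuntimeError("peft is not installed; the LoRA examples require peft>=0.6.0") from exc

if installed < Version("0.6.0"):
    raise RuntimeError(f"peft {installed} is installed, but the LoRA examples require peft>=0.6.0")

print(f"peft {installed} found; OK")
```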

examples/text_to_image/README_sdxl.md (+1)
@@ -45,6 +45,7 @@ write_basic_config()
```

When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
+Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.

### Training
