Commit 9cc96a6

[FIX] Fix TypeError in DreamBooth SDXL when use_dora is False (#9879)
* fix use_dora
* fix style and quality
* fix use_dora with peft version

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 5b972fb · commit 9cc96a6

File tree

1 file changed: +22 −14

Diff for: examples/dreambooth/train_dreambooth_lora_sdxl.py

@@ -67,6 +67,7 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_unet_state_dict_to_peft,
+    is_peft_version,
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -1183,26 +1184,33 @@ def main(args):
             text_encoder_one.gradient_checkpointing_enable()
             text_encoder_two.gradient_checkpointing_enable()

+    def get_lora_config(rank, use_dora, target_modules):
+        base_config = {
+            "r": rank,
+            "lora_alpha": rank,
+            "init_lora_weights": "gaussian",
+            "target_modules": target_modules,
+        }
+        if use_dora:
+            if is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                base_config["use_dora"] = True
+
+        return LoraConfig(**base_config)
+
     # now we will add new LoRA weights to the attention layers
-    unet_lora_config = LoraConfig(
-        r=args.rank,
-        use_dora=args.use_dora,
-        lora_alpha=args.rank,
-        init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
-    )
+    unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+    unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
     unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            use_dora=args.use_dora,
-            lora_alpha=args.rank,
-            init_lora_weights="gaussian",
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
-        )
+        text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+        text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
