Fix lora alpha and metadata handling #11739

Open · wants to merge 6 commits into main

18 changes: 16 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -62,6 +62,7 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import (
_collate_lora_metadata,
check_min_version,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
@@ -659,6 +660,12 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)

parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)

parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

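Note: with PEFT, the LoRA update is scaled by lora_alpha / r, so hard-wiring the alpha to the rank (the previous behaviour) always produced a scaling factor of 1.0. A minimal sketch of how the new flag maps onto a PEFT config; r=8 and lora_alpha=16 are illustrative placeholders, not defaults from this script:

from peft import LoraConfig

# Effective LoRA scaling is lora_alpha / r, i.e. 16 / 8 = 2.0 in this example.
unet_lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)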
@@ -1202,10 +1209,10 @@ def main(args):
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()

def get_lora_config(rank, dropout, use_dora, target_modules):
def get_lora_config(rank, lora_alpha, dropout, use_dora, target_modules):
base_config = {
"r": rank,
"lora_alpha": rank,
"lora_alpha":lora_alpha,
"lora_dropout": dropout,
"init_lora_weights": "gaussian",
"target_modules": target_modules,
@@ -1224,6 +1231,7 @@ def get_lora_config(rank, dropout, use_dora, target_modules):
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
unet_lora_config = get_lora_config(
rank=args.rank,
lora_alpha=args.lora_alpha,
dropout=args.lora_dropout,
use_dora=args.use_dora,
target_modules=unet_target_modules,
@@ -1236,6 +1244,7 @@ def get_lora_config(rank, dropout, use_dora, target_modules):
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
text_lora_config = get_lora_config(
rank=args.rank,
lora_alpha=args.lora_alpha,
dropout=args.lora_dropout,
use_dora=args.use_dora,
target_modules=text_target_modules,
@@ -1256,10 +1265,12 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
modules_to_save = {}

for model in models:
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
modules_to_save["transformer"] = model
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
@@ -1279,6 +1290,7 @@ def save_model_hook(models, weights, output_dir):
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)

def load_model_hook(models, input_dir):
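Note: save_model_hook now tracks the PEFT-wrapped models in modules_to_save so that _collate_lora_metadata can pass their adapter configs (including lora_alpha) to save_lora_weights alongside the state dicts. A minimal sketch of loading such a checkpoint back, assuming the stored metadata is picked up automatically by load_lora_weights; the model id and output directory are placeholders:

import torch
from diffusers import StableDiffusionXLPipeline

# Placeholder model id and output directory, for illustration only.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/output_dir")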
@@ -1945,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
modules_to_save = {}
unet = unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
@@ -1967,6 +1980,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
**_collate_lora_metadata(modules_to_save),
)
if args.output_kohya_format:
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
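Note: one way to check that the chosen lora_alpha actually made it into the saved checkpoint is to read the safetensors header directly; a minimal sketch, assuming the default filename written above:

from safetensors import safe_open

with safe_open("output_dir/pytorch_lora_weights.safetensors", framework="pt") as f:
    # The header metadata is expected to carry the serialized LoRA adapter config(s).
    print(f.metadata())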