diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 8f92e3b44295..c59036d13beb 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -34,7 +34,7 @@
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -53,8 +53,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -997,17 +1002,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
     if args.train_text_encoder:
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 424e95f0843e..7df91a05b593 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -581,7 +581,6 @@ def load_lora_into_text_encoder(
                 lora_config_kwargs = get_peft_kwargs(
                     rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
                 )
-
                 lora_config = LoraConfig(**lora_config_kwargs)
 
                 # adapter_name
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 9fb6fad3a267..8ff904305242 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -8,12 +8,21 @@
 from torchvision import transforms
 
 from .models import UNet2DConditionModel
-from .utils import deprecate, is_transformers_available
+from .utils import (
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
+    deprecate,
+    is_peft_available,
+    is_transformers_available,
+)
 
 
 if is_transformers_available():
     import transformers
 
+if is_peft_available():
+    from peft import set_peft_model_state_dict
+
 
 def set_seed(seed: int):
     """
@@ -112,6 +121,25 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def _set_state_dict_into_text_encoder(
+    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
+):
+    """
+    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.
+
+    Args:
+        lora_state_dict: The state dictionary to be set.
+        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
+        text_encoder: Where the `lora_state_dict` is to be set.
+ """ + + text_encoder_state_dict = { + f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix) + } + text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict)) + set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") + + # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """