From fa57c21fd824a2c0c3757793f85db6bf0e3c8fe6 Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Wed, 10 Jan 2024 13:28:59 +0530 Subject: [PATCH 1/8] fix: training resume from fp16. --- .../dreambooth/train_dreambooth_lora_sdxl.py | 107 +++++++++++------- src/diffusers/loaders/lora.py | 22 ++-- 2 files changed, 81 insertions(+), 48 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 122af23865b8..09be06f4ecb5 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -56,7 +56,8 @@ from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available - +from diffusers.utils.torch_utils import is_compiled_module +from diffusers.utils.peft_utils import delete_adapter_layers # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.26.0.dev0") @@ -780,12 +781,13 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False + text_input_ids.to(text_encoder.device), + output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds[-1][-2] + prompt_embeds = prompt_embeds.hidden_states[-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -996,16 +998,10 @@ def main(args): text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) - # Make sure the trainable params are in float32. 
- if args.mixed_precision == "fp16": - models = [unet] - if args.train_text_encoder: - models.extend([text_encoder_one, text_encoder_two]) - for model in models: - for param in model.parameters(): - # only upcast trainable parameters (LoRA) into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): @@ -1017,13 +1013,13 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) @@ -1048,27 +1044,47 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() - if isinstance(model, type(accelerator.unwrap_model(unet))): + if isinstance(model, type(unwrap_model(unet))): unet_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + elif isinstance(model, type(unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + elif isinstance(model, type(unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) - - text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ + # We pass the config here to ensure parity between `unet_lora_config` and + # the `LoraConfig` that's inferred in `load_lora_into_unet`. + LoraLoaderMixin.load_lora_into_unet( + lora_state_dict, network_alphas=network_alphas, unet=unet_, _config=unet_lora_config ) + # Remove the newly created adapter as we don't need it. + delete_adapter_layers(unet_, "default_1") - text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ - ) + if args.train_text_encoder: + text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} + LoraLoaderMixin.load_lora_into_text_encoder( + text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, _config=text_lora_config + ) + delete_adapter_layers(text_encoder_one_, "default_1") + + text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." 
in k} + LoraLoaderMixin.load_lora_into_text_encoder( + text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, _config=text_lora_config + ) + delete_adapter_layers(text_encoder_two_, "default_1") + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet_] + if args.train_text_encoder: + models.extend([text_encoder_one_, text_encoder_two_]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -1083,6 +1099,17 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) if args.train_text_encoder: @@ -1428,8 +1455,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, - return_dict=False, - )[0] + ).sample else: unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( @@ -1443,12 +1469,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - noisy_model_input, - timesteps, - prompt_embeds_input, - added_cond_kwargs=unet_added_conditions, - return_dict=False, - )[0] + noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions + ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -1499,7 +1521,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if args.train_text_encoder else unet_lora_parameters ) + params_to_clip_dtype = {param.dtype for param in params_to_clip} + print(f"params_to_clip_dtype: {params_to_clip_dtype}") accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + params_to_optimize_dtype = {param.dtype for param in params_to_optimize[0]["params"]} + print(f"params_to_optimize_dtype: {params_to_optimize_dtype}") optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -1621,16 +1648,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) - unet = unet.to(torch.float32) + unet = unwrap_model(unet) + unet = unet.to(torch.float32) unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: - text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_one = unwrap_model(text_encoder_one) text_encoder_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_one.to(torch.float32)) ) - text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_two = 
unwrap_model(text_encoder_two) text_encoder_2_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_two.to(torch.float32)) ) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 424e95f0843e..50002cd0c209 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -373,7 +373,7 @@ def _optionally_disable_offloading(cls, _pipeline): @classmethod def load_lora_into_unet( - cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None + cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _config=None, _pipeline=None ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -446,8 +446,11 @@ def load_lora_into_unet( if "lora_B" in key: rank[key] = val.shape[1] - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) - lora_config = LoraConfig(**lora_config_kwargs) + if _config is None: + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + lora_config = LoraConfig(**lora_config_kwargs) + else: + lora_config = _config # adapter_name if adapter_name is None: @@ -490,6 +493,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, low_cpu_mem_usage=None, adapter_name=None, + _config=None, _pipeline=None, ): """ @@ -578,11 +582,13 @@ def load_lora_into_text_encoder( if USE_PEFT_BACKEND: from peft import LoraConfig - lora_config_kwargs = get_peft_kwargs( - rank, network_alphas, text_encoder_lora_state_dict, is_unet=False - ) - - lora_config = LoraConfig(**lora_config_kwargs) + if _config is None: + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, is_unet=False + ) + lora_config = LoraConfig(**lora_config_kwargs) + else: + lora_config = _config # adapter_name if adapter_name is None: From 3994545e3f471a88eb0e768d83a34c59dc4d53c9 Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Wed, 10 Jan 2024 13:35:18 +0530 Subject: [PATCH 2/8] add: comment --- .../dreambooth/train_dreambooth_lora_sdxl.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 09be06f4ecb5..435c50ec38e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -56,8 +56,9 @@ from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.torch_utils import is_compiled_module from diffusers.utils.peft_utils import delete_adapter_layers +from diffusers.utils.torch_utils import is_compiled_module + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.26.0.dev0") @@ -1000,7 +1001,7 @@ def main(args): def unwrap_model(model): model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model + model = model._orig_mod if is_compiled_module(model) else model return model # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -1065,17 +1066,24 @@ def load_model_hook(models, input_dir): if args.train_text_encoder: text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." 
in k} LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, _config=text_lora_config + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=text_encoder_one_, + _config=text_lora_config, ) delete_adapter_layers(text_encoder_one_, "default_1") - + text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, _config=text_lora_config + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=text_encoder_two_, + _config=text_lora_config, ) delete_adapter_layers(text_encoder_two_, "default_1") - # Make sure the trainable params are in float32. + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. if args.mixed_precision == "fp16": models = [unet_] if args.train_text_encoder: @@ -1521,12 +1529,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if args.train_text_encoder else unet_lora_parameters ) - params_to_clip_dtype = {param.dtype for param in params_to_clip} - print(f"params_to_clip_dtype: {params_to_clip_dtype}") accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - params_to_optimize_dtype = {param.dtype for param in params_to_optimize[0]["params"]} - print(f"params_to_optimize_dtype: {params_to_optimize_dtype}") + optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -1649,7 +1653,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unwrap_model(unet) - unet = unet.to(torch.float32) + unet = unet.to(torch.float32) unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: From 356ef29639c66364ea2d7be5fbf0225736245e82 Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Wed, 10 Jan 2024 13:38:50 +0530 Subject: [PATCH 3/8] remove residue from another branch. --- .../dreambooth/train_dreambooth_lora_sdxl.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 435c50ec38e1..732f75a55b83 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -57,7 +57,6 @@ from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.peft_utils import delete_adapter_layers -from diffusers.utils.torch_utils import is_compiled_module # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
@@ -999,11 +998,6 @@ def main(args): text_encoder_one.add_adapter(text_lora_config) text_encoder_two.add_adapter(text_lora_config) - def unwrap_model(model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1014,13 +1008,13 @@ def save_model_hook(models, weights, output_dir): text_encoder_two_lora_layers_to_save = None for model in models: - if isinstance(model, type(unwrap_model(unet))): + if isinstance(model, type(accelerator.unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) - elif isinstance(model, type(unwrap_model(text_encoder_one))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) - elif isinstance(model, type(unwrap_model(text_encoder_two))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( get_peft_model_state_dict(model) ) @@ -1045,11 +1039,11 @@ def load_model_hook(models, input_dir): while len(models) > 0: model = models.pop() - if isinstance(model, type(unwrap_model(unet))): + if isinstance(model, type(accelerator.unwrap_model(unet))): unet_ = model - elif isinstance(model, type(unwrap_model(text_encoder_one))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): text_encoder_one_ = model - elif isinstance(model, type(unwrap_model(text_encoder_two))): + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): text_encoder_two_ = model else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1652,16 +1646,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = unwrap_model(unet) + unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) if args.train_text_encoder: - text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_one = accelerator.unwrap_model(text_encoder_one) text_encoder_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_one.to(torch.float32)) ) - text_encoder_two = unwrap_model(text_encoder_two) + text_encoder_two = accelerator.unwrap_model(text_encoder_two) text_encoder_2_lora_layers = convert_state_dict_to_diffusers( get_peft_model_state_dict(text_encoder_two.to(torch.float32)) ) From 32ecbee9f42b48a07abb397b6776e78269d02353 Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Wed, 10 Jan 2024 13:41:56 +0530 Subject: [PATCH 4/8] remove more residues. 
--- .../dreambooth/train_dreambooth_lora_sdxl.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 732f75a55b83..6dd2837167e2 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -781,13 +781,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): text_input_ids = text_input_ids_list[i] prompt_embeds = text_encoder( - text_input_ids.to(text_encoder.device), - output_hidden_states=True, + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds[-1][-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) @@ -1457,7 +1456,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions, - ).sample + return_dict=False, + )[0] else: unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( @@ -1471,8 +1471,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( - noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions - ).sample + noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": From 03bef1db8a71fa3285f110ae0ff4621bd17cee6e Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Wed, 10 Jan 2024 17:12:19 +0530 Subject: [PATCH 5/8] thanks to Younes; no hacks. --- .../dreambooth/train_dreambooth_lora_sdxl.py | 56 +++++++++++-------- src/diffusers/loaders/lora.py | 21 +++---- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 6dd2837167e2..068ab1253969 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -34,7 +34,7 @@ from huggingface_hub import create_repo, upload_folder from huggingface_hub.utils import insecure_hashlib from packaging import version -from peft import LoraConfig +from peft import LoraConfig, set_peft_model_state_dict from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose @@ -56,7 +56,7 @@ from diffusers.training_utils import compute_snr from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.peft_utils import delete_adapter_layers +from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
@@ -1048,32 +1048,40 @@ def load_model_hook(models, input_dir): raise ValueError(f"unexpected save model: {model.__class__}") lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) - # We pass the config here to ensure parity between `unet_lora_config` and - # the `LoraConfig` that's inferred in `load_lora_into_unet`. - LoraLoaderMixin.load_lora_into_unet( - lora_state_dict, network_alphas=network_alphas, unet=unet_, _config=unet_lora_config - ) - # Remove the newly created adapter as we don't need it. - delete_adapter_layers(unet_, "default_1") + + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) if args.train_text_encoder: - text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=text_encoder_one_, - _config=text_lora_config, + # Do we need to call `scale_lora_layers()` here? + text_encoder_state_dict = { + f'{k.replace("text_encoder.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("text_encoder.") + } + text_encoder_state_dict = convert_state_dict_to_peft( + convert_state_dict_to_diffusers(text_encoder_state_dict) ) - delete_adapter_layers(text_encoder_one_, "default_1") - - text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} - LoraLoaderMixin.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=network_alphas, - text_encoder=text_encoder_two_, - _config=text_lora_config, + set_peft_model_state_dict(text_encoder_one_, text_encoder_state_dict, adapter_name="default") + + text_encoder_2_state_dict = { + f'{k.replace("text_encoder_2.", "")}': v + for k, v in lora_state_dict.items() + if k.startswith("text_encoder_2.") + } + text_encoder_2_state_dict = convert_state_dict_to_peft( + convert_state_dict_to_diffusers(text_encoder_2_state_dict) ) - delete_adapter_layers(text_encoder_two_, "default_1") + set_peft_model_state_dict(text_encoder_two_, text_encoder_2_state_dict, adapter_name="default") # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 50002cd0c209..7df91a05b593 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -373,7 +373,7 @@ def _optionally_disable_offloading(cls, _pipeline): @classmethod def load_lora_into_unet( - cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _config=None, _pipeline=None + cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None ): """ This will load the LoRA layers specified in `state_dict` into `unet`. 
@@ -446,11 +446,8 @@ def load_lora_into_unet( if "lora_B" in key: rank[key] = val.shape[1] - if _config is None: - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) - lora_config = LoraConfig(**lora_config_kwargs) - else: - lora_config = _config + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True) + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: @@ -493,7 +490,6 @@ def load_lora_into_text_encoder( lora_scale=1.0, low_cpu_mem_usage=None, adapter_name=None, - _config=None, _pipeline=None, ): """ @@ -582,13 +578,10 @@ def load_lora_into_text_encoder( if USE_PEFT_BACKEND: from peft import LoraConfig - if _config is None: - lora_config_kwargs = get_peft_kwargs( - rank, network_alphas, text_encoder_lora_state_dict, is_unet=False - ) - lora_config = LoraConfig(**lora_config_kwargs) - else: - lora_config = _config + lora_config_kwargs = get_peft_kwargs( + rank, network_alphas, text_encoder_lora_state_dict, is_unet=False + ) + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: From 820522fbdb44c674904c355edaaa0d7de13b2c06 Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Fri, 12 Jan 2024 09:56:23 +0530 Subject: [PATCH 6/8] style. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9cd7770302cc..c740b4d303fc 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -150,7 +150,6 @@ def import_model_class_from_model_name_or_path( return CLIPTextModelWithProjection else: raise ValueError(f"{model_class} is not supported.") - def parse_args(input_args=None): From 85e6b6bc6cba54b8bcbeaadf76a1bb1aee0caa28 Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Fri, 12 Jan 2024 10:11:08 +0530 Subject: [PATCH 7/8] clean things a bit and modularize _set_state_dict_into_text_encoder --- .../dreambooth/train_dreambooth_lora_sdxl.py | 32 +++++++------------ src/diffusers/training_utils.py | 30 ++++++++++++++++- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c740b4d303fc..87bf4e9da2c4 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -53,10 +53,14 @@ ) from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr -from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available +from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft from diffusers.utils.torch_utils import is_compiled_module @@ -1069,25 +1073,11 @@ def load_model_hook(models, input_dir): if args.train_text_encoder: # Do we need to call `scale_lora_layers()` here? 
- text_encoder_state_dict = { - f'{k.replace("text_encoder.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("text_encoder.") - } - text_encoder_state_dict = convert_state_dict_to_peft( - convert_state_dict_to_diffusers(text_encoder_state_dict) - ) - set_peft_model_state_dict(text_encoder_one_, text_encoder_state_dict, adapter_name="default") - - text_encoder_2_state_dict = { - f'{k.replace("text_encoder_2.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("text_encoder_2.") - } - text_encoder_2_state_dict = convert_state_dict_to_peft( - convert_state_dict_to_diffusers(text_encoder_2_state_dict) + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_ ) - set_peft_model_state_dict(text_encoder_two_, text_encoder_2_state_dict, adapter_name="default") # Make sure the trainable params are in float32. This is again needed since the base models # are in `weight_dtype`. diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9fb6fad3a267..8ff904305242 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -8,12 +8,21 @@ from torchvision import transforms from .models import UNet2DConditionModel -from .utils import deprecate, is_transformers_available +from .utils import ( + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + deprecate, + is_peft_available, + is_transformers_available, +) if is_transformers_available(): import transformers +if is_peft_available(): + from peft import set_peft_model_state_dict + def set_seed(seed: int): """ @@ -112,6 +121,25 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: return lora_state_dict +def _set_state_dict_into_text_encoder( + lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module +): + """ + Sets the `lora_state_dict` into `text_encoder` coming from `transformers`. + + Args: + lora_state_dict: The state dictionary to be set. + prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`. + text_encoder: Where the `lora_state_dict` is to be set. + """ + + text_encoder_state_dict = { + f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix) + } + text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict)) + set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") + + # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ From e8f1d38641f0b5c759432d15e0b7493608c39d1a Mon Sep 17 00:00:00 2001 From: sayakpaul <spsayakpaul@gmail.com> Date: Fri, 12 Jan 2024 16:53:53 +0530 Subject: [PATCH 8/8] add comment about the fix detailed. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 87bf4e9da2c4..c59036d13beb 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1080,7 +1080,8 @@ def load_model_hook(models, input_dir): ) # Make sure the trainable params are in float32. This is again needed since the base models - # are in `weight_dtype`. + # are in `weight_dtype`. 
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
         if args.mixed_precision == "fp16":
             models = [unet_]
             if args.train_text_encoder:
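
The core of the series: when a run started with `--mixed_precision="fp16"` is resumed, `accelerator.load_state()` restores the LoRA weights in `weight_dtype` (fp16), so the trainable parameters must be upcast to fp32 again inside `load_model_hook` before gradient clipping and the optimizer step run. A minimal sketch of that upcast, mirroring the loop the patches add (the helper name here is illustrative, not part of the script):

    import torch

    def cast_trainable_params_to_fp32(models):
        # Only the LoRA adapter parameters require grad; the frozen base
        # weights stay in the mixed-precision dtype to keep memory low.
        for model in models:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

In the training script this loop runs in two places: once before the optimizer parameters are collected (fresh start) and once at the end of `load_model_hook` (resume), with the two text encoders included only when `--train_text_encoder` is set.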