From fa57c21fd824a2c0c3757793f85db6bf0e3c8fe6 Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Wed, 10 Jan 2024 13:28:59 +0530
Subject: [PATCH 1/8] fix: training resume from fp16.
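
Resuming LoRA DreamBooth SDXL training with --mixed_precision=fp16 failed
because the fp32 upcast of the trainable parameters only ran once at startup:
after `accelerator.load_state(...)` re-populated the adapters, the LoRA
parameters were left in reduced precision. This patch moves the upcast so it
also runs at the end of `load_model_hook`, and threads the training
`LoraConfig` into `load_lora_into_unet` / `load_lora_into_text_encoder` via a
private `_config` argument so the re-created adapters match the ones the
optimizer was built against (the extra "default_1" adapters created along the
way are deleted again).

A minimal sketch of the upcast step shared by the startup path and the load
hook (the helper name is illustrative, not part of the patch):

    import torch

    def upcast_trainable_params(*models):
        # Only the LoRA parameters require grad; the frozen base weights
        # stay in `weight_dtype` (fp16) to keep the mixed-precision footprint.
        for model in models:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

    # upcast_trainable_params(unet, text_encoder_one, text_encoder_two)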

---
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 107 +++++++++++-------
 src/diffusers/loaders/lora.py                 |  22 ++--
 2 files changed, 81 insertions(+), 48 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 122af23865b8..09be06f4ecb5 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -56,7 +56,8 @@
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-
+from diffusers.utils.torch_utils import is_compiled_module
+from diffusers.utils.peft_utils import delete_adapter_layers
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.26.0.dev0")
@@ -780,12 +781,13 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]
 
         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+            text_input_ids.to(text_encoder.device),
+            output_hidden_states=True,
         )
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds[-1][-2]
+        prompt_embeds = prompt_embeds.hidden_states[-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
@@ -996,16 +998,10 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model 
+        return model
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -1017,13 +1013,13 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -1048,27 +1044,47 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()
 
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+        # We pass the config here to ensure parity between `unet_lora_config` and
+        # the `LoraConfig` that's inferred in `load_lora_into_unet`.
+        LoraLoaderMixin.load_lora_into_unet(
+            lora_state_dict, network_alphas=network_alphas, unet=unet_, _config=unet_lora_config
         )
+        # Remove the newly created adapter as we don't need it.
+        delete_adapter_layers(unet_, "default_1")
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
+            LoraLoaderMixin.load_lora_into_text_encoder(
+                text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, _config=text_lora_config
+            )
+            delete_adapter_layers(text_encoder_one_, "default_1")
+        
+            text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
+            LoraLoaderMixin.load_lora_into_text_encoder(
+                text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, _config=text_lora_config
+            )
+            delete_adapter_layers(text_encoder_two_, "default_1")
+
+        # Make sure the trainable params are in float32.
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1083,6 +1099,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
     if args.train_text_encoder:
@@ -1428,8 +1455,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
-                        return_dict=False,
-                    )[0]
+                    ).sample
                 else:
                     unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
@@ -1443,12 +1469,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     )
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input,
-                        timesteps,
-                        prompt_embeds_input,
-                        added_cond_kwargs=unet_added_conditions,
-                        return_dict=False,
-                    )[0]
+                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
+                    ).sample
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
@@ -1499,7 +1521,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         if args.train_text_encoder
                         else unet_lora_parameters
                     )
+                    params_to_clip_dtype = {param.dtype for param in params_to_clip}
+                    print(f"params_to_clip_dtype: {params_to_clip_dtype}")
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                
+                params_to_optimize_dtype = {param.dtype for param in params_to_optimize[0]["params"]}
+                print(f"params_to_optimize_dtype: {params_to_optimize_dtype}")
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -1621,16 +1648,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
-        unet = unet.to(torch.float32)
+        unet = unwrap_model(unet)
+        unet = unet.to(torch.float32) 
         unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
         if args.train_text_encoder:
-            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
+            text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             )
-            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+            text_encoder_two = unwrap_model(text_encoder_two)
             text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_two.to(torch.float32))
             )
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 424e95f0843e..50002cd0c209 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -373,7 +373,7 @@ def _optionally_disable_offloading(cls, _pipeline):
 
     @classmethod
     def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _config=None, _pipeline=None
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -446,8 +446,11 @@ def load_lora_into_unet(
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
-            lora_config = LoraConfig(**lora_config_kwargs)
+            if _config is None:
+                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+                lora_config = LoraConfig(**lora_config_kwargs)
+            else:
+                lora_config = _config
 
             # adapter_name
             if adapter_name is None:
@@ -490,6 +493,7 @@ def load_lora_into_text_encoder(
         lora_scale=1.0,
         low_cpu_mem_usage=None,
         adapter_name=None,
+        _config=None,
         _pipeline=None,
     ):
         """
@@ -578,11 +582,13 @@ def load_lora_into_text_encoder(
                 if USE_PEFT_BACKEND:
                     from peft import LoraConfig
 
-                    lora_config_kwargs = get_peft_kwargs(
-                        rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
-                    )
-
-                    lora_config = LoraConfig(**lora_config_kwargs)
+                    if _config is None:
+                        lora_config_kwargs = get_peft_kwargs(
+                            rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+                        )
+                        lora_config = LoraConfig(**lora_config_kwargs)
+                    else:
+                        lora_config = _config
 
                     # adapter_name
                     if adapter_name is None:

From 3994545e3f471a88eb0e768d83a34c59dc4d53c9 Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Wed, 10 Jan 2024 13:35:18 +0530
Subject: [PATCH 2/8] add: comment

---
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 09be06f4ecb5..435c50ec38e1 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -56,8 +56,9 @@
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.torch_utils import is_compiled_module
 from diffusers.utils.peft_utils import delete_adapter_layers
+from diffusers.utils.torch_utils import is_compiled_module
+
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.26.0.dev0")
@@ -1000,7 +1001,7 @@ def main(args):
 
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
-        model = model._orig_mod if is_compiled_module(model) else model 
+        model = model._orig_mod if is_compiled_module(model) else model
         return model
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1065,17 +1066,24 @@ def load_model_hook(models, input_dir):
         if args.train_text_encoder:
             text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
             LoraLoaderMixin.load_lora_into_text_encoder(
-                text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_, _config=text_lora_config
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=text_encoder_one_,
+                _config=text_lora_config,
             )
             delete_adapter_layers(text_encoder_one_, "default_1")
-        
+
             text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
             LoraLoaderMixin.load_lora_into_text_encoder(
-                text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_, _config=text_lora_config
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=text_encoder_two_,
+                _config=text_lora_config,
             )
             delete_adapter_layers(text_encoder_two_, "default_1")
 
-        # Make sure the trainable params are in float32.
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`.
         if args.mixed_precision == "fp16":
             models = [unet_]
             if args.train_text_encoder:
@@ -1521,12 +1529,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         if args.train_text_encoder
                         else unet_lora_parameters
                     )
-                    params_to_clip_dtype = {param.dtype for param in params_to_clip}
-                    print(f"params_to_clip_dtype: {params_to_clip_dtype}")
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-                
-                params_to_optimize_dtype = {param.dtype for param in params_to_optimize[0]["params"]}
-                print(f"params_to_optimize_dtype: {params_to_optimize_dtype}")
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
@@ -1649,7 +1653,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unwrap_model(unet)
-        unet = unet.to(torch.float32) 
+        unet = unet.to(torch.float32)
         unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
         if args.train_text_encoder:

From 356ef29639c66364ea2d7be5fbf0225736245e82 Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Wed, 10 Jan 2024 13:38:50 +0530
Subject: [PATCH 3/8] remove residue from another branch.

---
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 24 +++++++------------
 1 file changed, 9 insertions(+), 15 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 435c50ec38e1..732f75a55b83 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -57,7 +57,6 @@
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.peft_utils import delete_adapter_layers
-from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -999,11 +998,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    def unwrap_model(model):
-        model = accelerator.unwrap_model(model)
-        model = model._orig_mod if is_compiled_module(model) else model
-        return model
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
@@ -1014,13 +1008,13 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(unwrap_model(unet))):
+                if isinstance(model, type(accelerator.unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
-                elif isinstance(model, type(unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -1045,11 +1039,11 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()
 
-            if isinstance(model, type(unwrap_model(unet))):
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1652,16 +1646,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = unwrap_model(unet)
+        unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
         unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
         if args.train_text_encoder:
-            text_encoder_one = unwrap_model(text_encoder_one)
+            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
             text_encoder_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             )
-            text_encoder_two = unwrap_model(text_encoder_two)
+            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
             text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
                 get_peft_model_state_dict(text_encoder_two.to(torch.float32))
             )

From 32ecbee9f42b48a07abb397b6776e78269d02353 Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Wed, 10 Jan 2024 13:41:56 +0530
Subject: [PATCH 4/8] remove more residues.

---
 .../dreambooth/train_dreambooth_lora_sdxl.py     | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 732f75a55b83..6dd2837167e2 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -781,13 +781,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]
 
         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
         )
 
         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds[-1][-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
@@ -1457,7 +1456,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         timesteps,
                         prompt_embeds_input,
                         added_cond_kwargs=unet_added_conditions,
-                    ).sample
+                        return_dict=False,
+                    )[0]
                 else:
                     unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
@@ -1471,8 +1471,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     )
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
-                        noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
-                    ).sample
+                        noisy_model_input,
+                        timesteps,
+                        prompt_embeds_input,
+                        added_cond_kwargs=unet_added_conditions,
+                        return_dict=False,
+                    )[0]
 
                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":

From 03bef1db8a71fa3285f110ae0ff4621bd17cee6e Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Wed, 10 Jan 2024 17:12:19 +0530
Subject: [PATCH 5/8] thanks to Younes; no hacks.
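
Drop the `_config` plumbing through `LoraLoaderMixin` (and the follow-up
`delete_adapter_layers(..., "default_1")` calls) entirely. Instead, the load
hook converts the checkpointed LoRA weights to peft format and sets them
directly into the adapters that are already attached to the models, using
peft's `set_peft_model_state_dict`; unexpected keys reported for the UNet are
logged as a warning.

Condensed from the hunk below, the UNet part of the new loading path looks
like this (the text encoders follow the same pattern with their prefixes):

    from peft import set_peft_model_state_dict
    from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft

    unet_state_dict = {
        k.replace("unet.", ""): v
        for k, v in lora_state_dict.items()
        if k.startswith("unet.")
    }
    unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
    set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")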

---
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 56 +++++++++++--------
 src/diffusers/loaders/lora.py                 | 21 +++----
 2 files changed, 39 insertions(+), 38 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 6dd2837167e2..068ab1253969 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -34,7 +34,7 @@
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -56,7 +56,7 @@
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.peft_utils import delete_adapter_layers
+from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1048,32 +1048,40 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        # We pass the config here to ensure parity between `unet_lora_config` and
-        # the `LoraConfig` that's inferred in `load_lora_into_unet`.
-        LoraLoaderMixin.load_lora_into_unet(
-            lora_state_dict, network_alphas=network_alphas, unet=unet_, _config=unet_lora_config
-        )
-        # Remove the newly created adapter as we don't need it.
-        delete_adapter_layers(unet_, "default_1")
+
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
         if args.train_text_encoder:
-            text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-            LoraLoaderMixin.load_lora_into_text_encoder(
-                text_encoder_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=text_encoder_one_,
-                _config=text_lora_config,
+            # Do we need to call `scale_lora_layers()` here?
+            text_encoder_state_dict = {
+                f'{k.replace("text_encoder.", "")}': v
+                for k, v in lora_state_dict.items()
+                if k.startswith("text_encoder.")
+            }
+            text_encoder_state_dict = convert_state_dict_to_peft(
+                convert_state_dict_to_diffusers(text_encoder_state_dict)
             )
-            delete_adapter_layers(text_encoder_one_, "default_1")
-
-            text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-            LoraLoaderMixin.load_lora_into_text_encoder(
-                text_encoder_2_state_dict,
-                network_alphas=network_alphas,
-                text_encoder=text_encoder_two_,
-                _config=text_lora_config,
+            set_peft_model_state_dict(text_encoder_one_, text_encoder_state_dict, adapter_name="default")
+
+            text_encoder_2_state_dict = {
+                f'{k.replace("text_encoder_2.", "")}': v
+                for k, v in lora_state_dict.items()
+                if k.startswith("text_encoder_2.")
+            }
+            text_encoder_2_state_dict = convert_state_dict_to_peft(
+                convert_state_dict_to_diffusers(text_encoder_2_state_dict)
             )
-            delete_adapter_layers(text_encoder_two_, "default_1")
+            set_peft_model_state_dict(text_encoder_two_, text_encoder_2_state_dict, adapter_name="default")
 
         # Make sure the trainable params are in float32. This is again needed since the base models
         # are in `weight_dtype`.
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 50002cd0c209..7df91a05b593 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -373,7 +373,7 @@ def _optionally_disable_offloading(cls, _pipeline):
 
     @classmethod
     def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _config=None, _pipeline=None
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -446,11 +446,8 @@ def load_lora_into_unet(
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
-            if _config is None:
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
-                lora_config = LoraConfig(**lora_config_kwargs)
-            else:
-                lora_config = _config
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
             if adapter_name is None:
@@ -493,7 +490,6 @@ def load_lora_into_text_encoder(
         lora_scale=1.0,
         low_cpu_mem_usage=None,
         adapter_name=None,
-        _config=None,
         _pipeline=None,
     ):
         """
@@ -582,13 +578,10 @@ def load_lora_into_text_encoder(
                 if USE_PEFT_BACKEND:
                     from peft import LoraConfig
 
-                    if _config is None:
-                        lora_config_kwargs = get_peft_kwargs(
-                            rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
-                        )
-                        lora_config = LoraConfig(**lora_config_kwargs)
-                    else:
-                        lora_config = _config
+                    lora_config_kwargs = get_peft_kwargs(
+                        rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+                    )
+                    lora_config = LoraConfig(**lora_config_kwargs)
 
                     # adapter_name
                     if adapter_name is None:

From 820522fbdb44c674904c355edaaa0d7de13b2c06 Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Fri, 12 Jan 2024 09:56:23 +0530
Subject: [PATCH 6/8] style.

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 9cd7770302cc..c740b4d303fc 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -150,7 +150,6 @@ def import_model_class_from_model_name_or_path(
         return CLIPTextModelWithProjection
     else:
         raise ValueError(f"{model_class} is not supported.")
-    
 
 
 def parse_args(input_args=None):

From 85e6b6bc6cba54b8bcbeaadf76a1bb1aee0caa28 Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Fri, 12 Jan 2024 10:11:08 +0530
Subject: [PATCH 7/8] clean things a bit and modularize
 _set_state_dict_into_text_encoder
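
The two text-encoder branches in `load_model_hook` duplicated the same
prefix-strip / convert-to-peft / `set_peft_model_state_dict` sequence, so that
sequence moves into a small helper, `_set_state_dict_into_text_encoder`, in
`diffusers.training_utils`, and the training script now imports
`convert_unet_state_dict_to_peft` from `diffusers.utils`.

Intended call pattern from the load hook (a sketch of what the hunk below
boils down to):

    from diffusers.training_utils import _set_state_dict_into_text_encoder

    # Strips `prefix` from the matching keys, converts them to peft format,
    # and sets them into the "default" adapter of the given text encoder.
    _set_state_dict_into_text_encoder(
        lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_
    )
    _set_state_dict_into_text_encoder(
        lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
    )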

---
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 32 +++++++------------
 src/diffusers/training_utils.py               | 30 ++++++++++++++++-
 2 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index c740b4d303fc..87bf4e9da2c4 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -53,10 +53,14 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.state_dict_utils import convert_state_dict_to_peft, convert_unet_state_dict_to_peft
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -1069,25 +1073,11 @@ def load_model_hook(models, input_dir):
 
         if args.train_text_encoder:
             # Do we need to call `scale_lora_layers()` here?
-            text_encoder_state_dict = {
-                f'{k.replace("text_encoder.", "")}': v
-                for k, v in lora_state_dict.items()
-                if k.startswith("text_encoder.")
-            }
-            text_encoder_state_dict = convert_state_dict_to_peft(
-                convert_state_dict_to_diffusers(text_encoder_state_dict)
-            )
-            set_peft_model_state_dict(text_encoder_one_, text_encoder_state_dict, adapter_name="default")
-
-            text_encoder_2_state_dict = {
-                f'{k.replace("text_encoder_2.", "")}': v
-                for k, v in lora_state_dict.items()
-                if k.startswith("text_encoder_2.")
-            }
-            text_encoder_2_state_dict = convert_state_dict_to_peft(
-                convert_state_dict_to_diffusers(text_encoder_2_state_dict)
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
             )
-            set_peft_model_state_dict(text_encoder_two_, text_encoder_2_state_dict, adapter_name="default")
 
         # Make sure the trainable params are in float32. This is again needed since the base models
         # are in `weight_dtype`.
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 9fb6fad3a267..8ff904305242 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -8,12 +8,21 @@
 from torchvision import transforms
 
 from .models import UNet2DConditionModel
-from .utils import deprecate, is_transformers_available
+from .utils import (
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
+    deprecate,
+    is_peft_available,
+    is_transformers_available,
+)
 
 
 if is_transformers_available():
     import transformers
 
+if is_peft_available():
+    from peft import set_peft_model_state_dict
+
 
 def set_seed(seed: int):
     """
@@ -112,6 +121,25 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def _set_state_dict_into_text_encoder(
+    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
+):
+    """
+    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.
+
+    Args:
+        lora_state_dict: The state dictionary to be set.
+        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
+        text_encoder: Where the `lora_state_dict` is to be set.
+    """
+
+    text_encoder_state_dict = {
+        f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
+    }
+    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
+    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """

From e8f1d38641f0b5c759432d15e0b7493608c39d1a Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Fri, 12 Jan 2024 16:53:53 +0530
Subject: [PATCH 8/8] add comment about the fix detailed.
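
Expand the comment in `load_model_hook` with a pointer to the review thread
(https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804)
explaining why the fp32 upcast has to be repeated after resuming: the base
models are kept in `weight_dtype`, so the restored LoRA parameters need to be
upcast again before the optimizer steps on them.

A quick way to verify the hook leaves the trainable parameters in the right
dtype (an illustrative check, not part of the patch):

    # assuming --train_text_encoder; otherwise check only `unet_`
    trainable_dtypes = {
        param.dtype
        for model in (unet_, text_encoder_one_, text_encoder_two_)
        for param in model.parameters()
        if param.requires_grad
    }
    assert trainable_dtypes == {torch.float32}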

---
 examples/dreambooth/train_dreambooth_lora_sdxl.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 87bf4e9da2c4..c59036d13beb 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1080,7 +1080,8 @@ def load_model_hook(models, input_dir):
             )
 
         # Make sure the trainable params are in float32. This is again needed since the base models
-        # are in `weight_dtype`.
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
         if args.mixed_precision == "fp16":
             models = [unet_]
             if args.train_text_encoder: