diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 8f92e3b44295..c59036d13beb 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -34,7 +34,7 @@
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -53,8 +53,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -997,17 +1002,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
 
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
 
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
     if args.train_text_encoder:
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 424e95f0843e..7df91a05b593 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -581,7 +581,6 @@ def load_lora_into_text_encoder(
                     lora_config_kwargs = get_peft_kwargs(
                         rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
                     )
-
                     lora_config = LoraConfig(**lora_config_kwargs)
 
                     # adapter_name
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 9fb6fad3a267..8ff904305242 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -8,12 +8,21 @@
 from torchvision import transforms
 
 from .models import UNet2DConditionModel
-from .utils import deprecate, is_transformers_available
+from .utils import (
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
+    deprecate,
+    is_peft_available,
+    is_transformers_available,
+)
 
 
 if is_transformers_available():
     import transformers
 
+if is_peft_available():
+    from peft import set_peft_model_state_dict
+
 
 def set_seed(seed: int):
     """
@@ -112,6 +121,25 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def _set_state_dict_into_text_encoder(
+    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
+):
+    """
+    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.
+
+    Args:
+        lora_state_dict: The state dictionary to be set.
+        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
+        text_encoder: Where the `lora_state_dict` is to be set.
+    """
+
+    text_encoder_state_dict = {
+        f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
+    }
+    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
+    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """