
Commit a663a67

sayakpaul authored and DN6 committed
[LoRA] clean up load_lora_into_text_encoder() and fuse_lora() copied from (#10495)
* factor out text encoder loading.
* make fix-copies
* remove copied from fuse_lora and unfuse_lora as needed.
* remove unused imports
1 parent 526858c · commit a663a67

File tree

4 files changed: +231 -679 lines changed


src/diffusers/loaders/lora_base.py (+156 -21)
@@ -28,13 +28,20 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
     delete_adapter_layers,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
+    is_peft_version,
     is_transformers_available,
+    is_transformers_version,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
@@ -43,6 +50,8 @@
 if is_transformers_available():
     from transformers import PreTrainedModel
 
+    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer
 
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
     return weight_name
 
 
+def _load_lora_into_text_encoder(
+    state_dict,
+    network_alphas,
+    text_encoder,
+    prefix=None,
+    lora_scale=1.0,
+    text_encoder_name="text_encoder",
+    adapter_name=None,
+    _pipeline=None,
+    low_cpu_mem_usage=False,
+):
+    if not USE_PEFT_BACKEND:
+        raise ValueError("PEFT backend is required for this method.")
+
+    peft_kwargs = {}
+    if low_cpu_mem_usage:
+        if not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+        if not is_transformers_version(">", "4.45.2"):
+            # Note from sayakpaul: It's not in `transformers` stable yet.
+            # https://github.com/huggingface/transformers/pull/33725/
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+            )
+        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+    from peft import LoraConfig
+
+    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
+    # their prefixes.
+    keys = list(state_dict.keys())
+    prefix = text_encoder_name if prefix is None else prefix
+
+    # Safe prefix to check with.
+    if any(text_encoder_name in key for key in keys):
+        # Load the layers corresponding to text encoder and make necessary adjustments.
+        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+        text_encoder_lora_state_dict = {
+            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+        }
+
+        if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {prefix}.")
+            rank = {}
+            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+            # convert state dict
+            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            for name, _ in text_encoder_mlp_modules(text_encoder):
+                for module in ("fc1", "fc2"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            if network_alphas is not None:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+
+            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+                **peft_kwargs,
+            )
+
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)
+
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+
+def _func_optionally_disable_offloading(_pipeline):
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+
+    if _pipeline is not None and _pipeline.hf_device_map is None:
+        for _, component in _pipeline.components.items():
+            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if not is_model_cpu_offload:
+                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                if not is_sequential_cpu_offload:
+                    is_sequential_cpu_offload = (
+                        isinstance(component._hf_hook, AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
+
+                logger.info(
+                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                )
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+    return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""
 
@@ -327,27 +482,7 @@ def _optionally_disable_offloading(cls, _pipeline):
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
-        return (is_model_cpu_offload, is_sequential_cpu_offload)
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     @classmethod
     def _fetch_state_dict(cls, *args, **kwargs):
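The other three changed files are not shown in this excerpt; they switch the per-pipeline loader mixins over to the shared helper above. A minimal, hypothetical sketch of that delegation pattern follows. The mixin name is invented for illustration; the real call sites and exact signatures live in files such as src/diffusers/loaders/lora_pipeline.py.

# Hypothetical sketch of the delegation pattern enabled by this refactor.
# `ExampleLoraLoaderMixin` is invented for illustration; the actual mixins
# live in src/diffusers/loaders/lora_pipeline.py (not part of this excerpt).
from diffusers.loaders.lora_base import _load_lora_into_text_encoder


class ExampleLoraLoaderMixin:
    text_encoder_name = "text_encoder"

    @classmethod
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
    ):
        # Instead of carrying a near-identical ~150-line body per pipeline
        # (the pre-refactor state this commit removes, hence -679 lines),
        # each mixin forwards to the single module-level implementation.
        _load_lora_into_text_encoder(
            state_dict=state_dict,
            network_alphas=network_alphas,
            text_encoder=text_encoder,
            prefix=prefix,
            lora_scale=lora_scale,
            text_encoder_name=cls.text_encoder_name,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )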
