 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
     delete_adapter_layers,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
+    is_peft_version,
     is_transformers_available,
+    is_transformers_version,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
@@ -43,6 +50,8 @@
 if is_transformers_available():
     from transformers import PreTrainedModel

+    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer

@@ -297,6 +306,152 @@ def _best_guess_weight_name(
     return weight_name


+def _load_lora_into_text_encoder(
+    state_dict,
+    network_alphas,
+    text_encoder,
+    prefix=None,
+    lora_scale=1.0,
+    text_encoder_name="text_encoder",
+    adapter_name=None,
+    _pipeline=None,
+    low_cpu_mem_usage=False,
+):
+    if not USE_PEFT_BACKEND:
+        raise ValueError("PEFT backend is required for this method.")
+
+    peft_kwargs = {}
+    if low_cpu_mem_usage:
+        if not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+        if not is_transformers_version(">", "4.45.2"):
+            # Note from sayakpaul: It's not in `transformers` stable yet.
+            # https://github.com/huggingface/transformers/pull/33725/
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+            )
+        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+    from peft import LoraConfig
+
+    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
+    # their prefixes.
+    keys = list(state_dict.keys())
+    prefix = text_encoder_name if prefix is None else prefix
+
+    # Safe prefix to check with.
+    if any(text_encoder_name in key for key in keys):
+        # Load the layers corresponding to text encoder and make necessary adjustments.
+        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+        text_encoder_lora_state_dict = {
+            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+        }
+
+        if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {prefix}.")
+            rank = {}
+            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+            # convert state dict
+            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            for name, _ in text_encoder_mlp_modules(text_encoder):
+                for module in ("fc1", "fc2"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            if network_alphas is not None:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+
+            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+                **peft_kwargs,
+            )
+
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)
+
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+
+def _func_optionally_disable_offloading(_pipeline):
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+
+    if _pipeline is not None and _pipeline.hf_device_map is None:
+        for _, component in _pipeline.components.items():
+            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if not is_model_cpu_offload:
+                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                if not is_sequential_cpu_offload:
+                    is_sequential_cpu_offload = (
+                        isinstance(component._hf_hook, AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
+
+                logger.info(
+                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                )
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+    return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""

@@ -327,27 +482,7 @@ def _optionally_disable_offloading(cls, _pipeline):
         tuple:
             A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
-        return (is_model_cpu_offload, is_sequential_cpu_offload)
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)

     @classmethod
     def _fetch_state_dict(cls, *args, **kwargs):
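
For reviewers, a minimal usage sketch of the new helper follows. It is not part of the diff: the import path (`diffusers.loaders.lora_base`), the pipeline class, and the repo ids are assumptions, and `lora_state_dict` returning a `(state_dict, network_alphas)` pair is the behavior of the existing Stable Diffusion LoRA loader mixin, not something introduced here.

# Hypothetical sketch -- exercises the new private helper directly.
# The module path and repo ids below are assumptions, not part of this diff.
from diffusers import StableDiffusionPipeline
from diffusers.loaders.lora_base import _load_lora_into_text_encoder

pipe = StableDiffusionPipeline.from_pretrained("some-org/some-sd-checkpoint")

# Existing mixin API: returns the raw LoRA weights and any network alphas.
state_dict, network_alphas = pipe.lora_state_dict("some-org/some-lora")

_load_lora_into_text_encoder(
    state_dict=state_dict,
    network_alphas=network_alphas,
    text_encoder=pipe.text_encoder,
    prefix="text_encoder",  # keys in `state_dict` are expected to carry this prefix
    lora_scale=1.0,
    adapter_name="my_lora",
    _pipeline=pipe,  # lets the helper temporarily lift accelerate offload hooks
    low_cpu_mem_usage=False,
)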