@@ -1495,10 +1495,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):

    @classmethod
    @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        return_alphas: bool = False,
        **kwargs,
    ):
        r"""
@@ -1583,7 +1583,26 @@ def lora_state_dict(
            allow_pickle=allow_pickle,
        )

-        return state_dict
+        # For state dicts like
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+        keys = list(state_dict.keys())
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.get(k)
+                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                ):
+                    network_alphas[k] = state_dict.pop(k)
+                else:
+                    raise ValueError(
+                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open an issue."
+                    )
+
+        if return_alphas:
+            return state_dict, network_alphas
+        else:
+            return state_dict

    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
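A minimal, self-contained sketch of what the new alpha filtering does, using a hand-built state dict (the key names below are invented for illustration, not taken from a real checkpoint): entries whose key contains `alpha` and whose value is a scalar float (or floating-point tensor) are popped into `network_alphas`, while the LoRA weight matrices stay in `state_dict`.

```python
import torch

# Hypothetical LoRA state dict mixing weight matrices and kohya-style alpha scalars.
state_dict = {
    "transformer.single_blocks.0.attn.to_q.lora_A.weight": torch.randn(4, 64),
    "transformer.single_blocks.0.attn.to_q.lora_B.weight": torch.randn(64, 4),
    "transformer.single_blocks.0.attn.to_q.alpha": torch.tensor(8.0),
}

# Same filtering rule as the diff: pop scalar float "alpha" entries into network_alphas.
network_alphas = {}
for k in list(state_dict.keys()):
    if "alpha" in k:
        v = state_dict.get(k)
        if (torch.is_tensor(v) and torch.is_floating_point(v)) or isinstance(v, float):
            network_alphas[k] = state_dict.pop(k)

print(network_alphas)   # {'transformer.single_blocks.0.attn.to_q.alpha': tensor(8.)}
print(len(state_dict))  # 2 -- only the lora_A/lora_B weights remain
```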
@@ -1617,14 +1636,17 @@ def load_lora_weights(
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+        )

        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
+            network_alphas=network_alphas,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
@@ -1634,7 +1656,7 @@ def load_lora_weights(
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
-                network_alphas=None,
+                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
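For context, this is roughly how the change surfaces to users: `load_lora_weights` now requests the alphas internally, so kohya-style Flux checkpoints with per-module alpha keys load without extra steps. A sketch only; the `weight_name` below is assumed rather than verified, and the repository is the one referenced in the comment inside `lora_state_dict`.

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# load_lora_weights now calls lora_state_dict(..., return_alphas=True) internally and
# forwards network_alphas to both the transformer and text-encoder loaders.
# NOTE: weight_name is illustrative, not confirmed for this repository.
pipe.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
```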
@@ -1643,8 +1665,7 @@ def load_lora_weights(
            )

    @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1653,6 +1674,10 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            transformer (`SD3Transformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
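As background for the docstring above (not part of the diff): in PEFT, the alpha ends up scaling each LoRA update by `lora_alpha / r`, which is why carrying the checkpoint's alpha values through matters for output fidelity. A toy illustration in plain tensor math:

```python
import torch

rank, alpha = 16, 8.0
lora_A = torch.randn(rank, 64)  # down-projection
lora_B = torch.randn(64, rank)  # up-projection

# PEFT applies the LoRA update as (alpha / rank) * (B @ A); alpha < rank attenuates it.
scaling = alpha / rank          # 0.5
delta_W = scaling * (lora_B @ lora_A)
print(delta_W.shape, scaling)   # torch.Size([64, 64]) 0.5
```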
@@ -1684,7 +1709,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
                if "lora_B" in key:
                    rank[key] = val.shape[1]

-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+            if network_alphas is not None and len(network_alphas) >= 1:
+                prefix = cls.transformer_name
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                    raise ValueError(
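A small sketch of the prefix filtering added above, with made-up keys: only alphas whose first key segment matches the transformer prefix survive, and they are re-keyed so they line up with the `state_dict`, which has already had the same prefix stripped earlier in `load_lora_into_transformer`.

```python
# Hypothetical network_alphas dict; key names are illustrative only.
prefix = "transformer"
network_alphas = {
    "transformer.single_blocks.0.attn.to_q.alpha": 8.0,
    "text_encoder.text_model.encoder.layers.0.alpha": 4.0,
}

# Keep only transformer-prefixed alphas, then strip the prefix to match state_dict keys.
alpha_keys = [k for k in network_alphas if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

print(network_alphas)  # {'single_blocks.0.attn.to_q.alpha': 8.0}
```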