
Commit 5090b09

[Flux LoRA] support parsing alpha from a flux lora state dict. (#9236)
* support parsing alpha from a flux lora state dict.
* conditional import.
* fix breaking changes.
* safeguard alpha.
* fix
1 parent 32d6492 commit 5090b09

2 files changed (+94 -9 lines)
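For reference, a hedged usage sketch (not part of the commit) of what the new `return_alphas` flag exposes to callers; the repository and file name come from the link cited in the diff below, and the call otherwise uses the existing `lora_state_dict` / `load_lora_weights` API:

from diffusers import FluxPipeline

# `return_alphas=True` splits any `.alpha` entries out of the LoRA weights
# instead of leaving them in the state dict that gets loaded into the transformer.
state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA",
    weight_name="jon_snow.safetensors",
    return_alphas=True,
)
print(len(state_dict), "lora tensors,", len(network_alphas), "alpha entries")

# `pipe.load_lora_weights(...)` performs the same split internally and forwards
# `network_alphas` to `load_lora_into_transformer` and the text encoder loader.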

Diff for: src/diffusers/loaders/lora_pipeline.py

+37 -7

@@ -1495,10 +1495,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        return_alphas: bool = False,
         **kwargs,
     ):
         r"""
@@ -1583,7 +1583,26 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )
 
-        return state_dict
+        # For state dicts like
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+        keys = list(state_dict.keys())
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.get(k)
+                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                ):
+                    network_alphas[k] = state_dict.pop(k)
+                else:
+                    raise ValueError(
+                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
+                    )
+
+        if return_alphas:
+            return state_dict, network_alphas
+        else:
+            return state_dict
 
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1617,14 +1636,17 @@ def load_lora_weights(
         pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+        )
 
         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
         self.load_lora_into_transformer(
             state_dict,
+            network_alphas=network_alphas,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
@@ -1634,7 +1656,7 @@ def load_lora_weights(
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
-                network_alphas=None,
+                network_alphas=network_alphas,
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
@@ -1643,8 +1665,7 @@ def load_lora_weights(
         )
 
     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1653,6 +1674,10 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             transformer (`SD3Transformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
@@ -1684,7 +1709,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
             if "lora_B" in key:
                 rank[key] = val.shape[1]
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
+        if network_alphas is not None and len(network_alphas) >= 1:
+            prefix = cls.transformer_name
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
         if "use_dora" in lora_config_kwargs:
            if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                raise ValueError(
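To make the flow above concrete, here is a small, self-contained sketch (assumptions: toy tensor shapes and an illustrative key name; real checkpoints use the Flux transformer's actual module paths) of what the new code does with an `.alpha` entry, and how PEFT ultimately turns it into a scaling factor of alpha / rank:

import torch

# Toy kohya-style entries for a single `to_k` projection (the key name is illustrative).
state_dict = {
    "transformer.blocks.0.attn.to_k.lora_A.weight": torch.randn(4, 64),
    "transformer.blocks.0.attn.to_k.lora_B.weight": torch.randn(64, 4),
    "transformer.blocks.0.attn.to_k.alpha": torch.tensor(16.0),
}

# Step 1 -- `lora_state_dict(..., return_alphas=True)`: pop `.alpha` keys into `network_alphas`.
network_alphas = {k: state_dict.pop(k) for k in list(state_dict) if "alpha" in k}

# Step 2 -- `load_lora_into_transformer`: keep only transformer-scoped alphas and drop the prefix.
prefix = "transformer"
network_alphas = {
    k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k.split(".")[0] == prefix
}

# Step 3 -- `get_peft_kwargs` / PEFT: the rank comes from `lora_B.shape[1]`, and the
# effective LoRA scaling applied to the adapter output is alpha / rank.
rank = state_dict["transformer.blocks.0.attn.to_k.lora_B.weight"].shape[1]
alpha = float(network_alphas["blocks.0.attn.to_k.alpha"])
print(f"rank={rank}, alpha={alpha}, scaling={alpha / rank}")  # rank=4, alpha=16.0, scaling=4.0

The type check in the diff (floating-point tensor or Python float) guards against malformed `.alpha` entries that would otherwise be passed along as LoRA weights.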

Diff for: tests/lora/test_lora_layers_flux.py

+57 -2

@@ -12,19 +12,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
+import tempfile
 import unittest
 
+import numpy as np
+import safetensors.torch
 import torch
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
 
 
+if is_peft_available():
+    from peft.utils import get_peft_model_state_dict
+
 sys.path.append(".")
 
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
 
 
 @require_peft_backend
@@ -90,3 +97,51 @@ def get_dummy_inputs(self, with_generator=True):
         pipeline_inputs.update({"generator": generator})
 
         return noise, input_ids, pipeline_inputs
+
+    def test_with_alpha_in_state_dict(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        pipe.transformer.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            # modify the state dict to have alpha values following
+            # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
+            state_dict_with_alpha = safetensors.torch.load_file(
+                os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
+            )
+            alpha_dict = {}
+            for k, v in state_dict_with_alpha.items():
+                # only do for `transformer` and for the k projections -- should be enough to test.
+                if "transformer" in k and "to_k" in k and "lora_A" in k:
+                    alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
+            state_dict_with_alpha.update(alpha_dict)
+
+        images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+        pipe.unload_lora_weights()
+        pipe.load_lora_weights(state_dict_with_alpha)
+        images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertTrue(
+            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+            "Loading from saved checkpoints should give same results.",
+        )
+        self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
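The alpha-injection trick the test performs can be reproduced on any Flux LoRA saved by `save_lora_weights`; a hedged sketch (the file name is whatever `save_lora_weights` wrote for you, and the fixed alpha of 8.0 is an arbitrary choice):

import safetensors.torch
import torch

path = "pytorch_lora_weights.safetensors"  # produced earlier by `save_lora_weights`
sd = safetensors.torch.load_file(path)

# Mirror the test: attach one `.alpha` entry per transformer lora_A weight.
alphas = {f"{k}.alpha": torch.tensor(8.0) for k in sd if "transformer" in k and "lora_A" in k}
sd.update(alphas)

# The updated dict can be passed directly to `pipe.load_lora_weights(sd)`; the new
# parsing pops the `.alpha` entries into `network_alphas` before loading the weights.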
