Skip to content

Commit 6905951

Browse files
raulmosaraul_arsayakpaul
committed
Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers (#9915)
* Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers to fix bug on updating keys and old_state_dict --------- Co-authored-by: raul_ar <raul.moreno.salinas@autoretouch.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 6b44c28 commit 6905951

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

src/diffusers/loaders/lora_conversion_utils.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,15 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
636636
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
637637
new_key = f"transformer.single_transformer_blocks.{block_num}"
638638

639-
if "proj_lora1" in old_key or "proj_lora2" in old_key:
639+
if "proj_lora" in old_key:
640640
new_key += ".proj_out"
641-
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
642-
new_key += ".norm.linear"
641+
elif "qkv_lora" in old_key and "up" not in old_key:
642+
handle_qkv(
643+
old_state_dict,
644+
new_state_dict,
645+
old_key,
646+
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
647+
)
643648

644649
if "down" in old_key:
645650
new_key += ".lora_A.weight"

tests/lora/test_lora_layers_flux.py

+25
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,28 @@ def test_flux_xlabs(self):
282282
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
283283

284284
assert max_diff < 1e-3
285+
286+
def test_flux_xlabs_load_lora_with_single_blocks(self):
287+
self.pipeline.load_lora_weights(
288+
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
289+
)
290+
self.pipeline.fuse_lora()
291+
self.pipeline.unload_lora_weights()
292+
self.pipeline.enable_model_cpu_offload()
293+
294+
prompt = "a wizard mouse playing chess"
295+
296+
out = self.pipeline(
297+
prompt,
298+
num_inference_steps=self.num_inference_steps,
299+
guidance_scale=3.5,
300+
output_type="np",
301+
generator=torch.manual_seed(self.seed),
302+
).images
303+
out_slice = out[0, -3:, -3:, -1].flatten()
304+
expected_slice = np.array(
305+
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
306+
)
307+
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
308+
309+
assert max_diff < 1e-3

0 commit comments

Comments
 (0)