Skip to content

Commit 255ac59

Browse files
authored
[Single File] Support loading Comfy UI Flux checkpoints (#9243)
update
1 parent 2d9ccf3 commit 255ac59

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/diffusers/loaders/single_file_utils.py

+18-6
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,10 @@
7979
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
8080
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
8181
"animatediff_rgb": "controlnet_cond_embedding.weight",
82-
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
82+
"flux": [
83+
"double_blocks.0.img_attn.norm.key_norm.scale",
84+
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
85+
],
8386
}
8487

8588
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -258,7 +261,7 @@
258261
"timestep_spacing": "leading",
259262
}
260263

261-
LDM_VAE_KEY = "first_stage_model."
264+
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
262265
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
263266
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
264267
LDM_UNET_KEY = "model.diffusion_model."
@@ -267,7 +270,6 @@
267270
"cond_stage_model.transformer.",
268271
"conditioner.embedders.0.transformer.",
269272
]
270-
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
271273
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
272274
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
273275

@@ -523,8 +525,10 @@ def infer_diffusers_model_type(checkpoint):
523525
else:
524526
model_type = "animatediff_v3"
525527

526-
elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
527-
if "guidance_in.in_layer.bias" in checkpoint:
528+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
529+
if any(
530+
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
531+
):
528532
model_type = "flux-dev"
529533
else:
530534
model_type = "flux-schnell"
@@ -1183,7 +1187,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
11831187
# remove the matching prefix (one of LDM_VAE_KEYS) from the ldm checkpoint keys so that it is easier to map them to diffusers keys
11841188
vae_state_dict = {}
11851189
keys = list(checkpoint.keys())
1186-
vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
1190+
vae_key = ""
1191+
for ldm_vae_key in LDM_VAE_KEYS:
1192+
if any(k.startswith(ldm_vae_key) for k in keys):
1193+
vae_key = ldm_vae_key
1194+
11871195
for key in keys:
11881196
if key.startswith(vae_key):
11891197
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
@@ -1896,6 +1904,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
18961904

18971905
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
18981906
converted_state_dict = {}
1907+
keys = list(checkpoint.keys())
1908+
for k in keys:
1909+
if "model.diffusion_model." in k:
1910+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
18991911

19001912
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
19011913
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401

0 commit comments

Comments
 (0)