|
79 | 79 | "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
80 | 80 | "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
81 | 81 | "animatediff_rgb": "controlnet_cond_embedding.weight",
|
82 |
| - "flux": "double_blocks.0.img_attn.norm.key_norm.scale", |
| 82 | + "flux": [ |
| 83 | + "double_blocks.0.img_attn.norm.key_norm.scale", |
| 84 | + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", |
| 85 | + ], |
83 | 86 | }
|
84 | 87 |
|
85 | 88 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
|
258 | 261 | "timestep_spacing": "leading",
|
259 | 262 | }
|
260 | 263 |
|
261 |
| -LDM_VAE_KEY = "first_stage_model." |
| 264 | +LDM_VAE_KEYS = ["first_stage_model.", "vae."] |
262 | 265 | LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
263 | 266 | PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
264 | 267 | LDM_UNET_KEY = "model.diffusion_model."
|
|
267 | 270 | "cond_stage_model.transformer.",
|
268 | 271 | "conditioner.embedders.0.transformer.",
|
269 | 272 | ]
|
270 |
| -OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." |
271 | 273 | LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
272 | 274 | SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
|
273 | 275 |
|
@@ -523,8 +525,10 @@ def infer_diffusers_model_type(checkpoint):
|
523 | 525 | else:
|
524 | 526 | model_type = "animatediff_v3"
|
525 | 527 |
|
526 |
| - elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint: |
527 |
| - if "guidance_in.in_layer.bias" in checkpoint: |
| 528 | + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): |
| 529 | + if any( |
| 530 | + g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] |
| 531 | + ): |
528 | 532 | model_type = "flux-dev"
|
529 | 533 | else:
|
530 | 534 | model_type = "flux-schnell"
|
@@ -1183,7 +1187,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
1183 | 1187 | # remove the matching LDM_VAE_KEYS prefix (if any) from the ldm checkpoint keys so that it is easier to map them to diffusers keys
|
1184 | 1188 | vae_state_dict = {}
|
1185 | 1189 | keys = list(checkpoint.keys())
|
1186 |
| - vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" |
| 1190 | + vae_key = "" |
| 1191 | + for ldm_vae_key in LDM_VAE_KEYS: |
| 1192 | + if any(k.startswith(ldm_vae_key) for k in keys): |
| 1193 | + vae_key = ldm_vae_key |
| 1194 | + |
1187 | 1195 | for key in keys:
|
1188 | 1196 | if key.startswith(vae_key):
|
1189 | 1197 | vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
@@ -1896,6 +1904,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
1896 | 1904 |
|
1897 | 1905 | def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
1898 | 1906 | converted_state_dict = {}
|
| 1907 | + keys = list(checkpoint.keys()) |
| 1908 | + for k in keys: |
| 1909 | + if "model.diffusion_model." in k: |
| 1910 | + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) |
1899 | 1911 |
|
1900 | 1912 | num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
1901 | 1913 | num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
|
0 commit comments