Commit 39794ec

asrimanth and sayakpaul authored and Jimmy committed
Fix: training resume from fp16 for SDXL Consistency Distillation (huggingface#6840)

* Fix: training resume from fp16 for lcm distill lora sdxl
* Fix coding quality - run linter
* Fix 1 - shift mixed precision cast before optimizer
* Fix 2 - State dict errors by removing load_lora_into_unet
* Update train_lcm_distill_lora_sdxl.py - Revert default cache dir to None

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent ec903e4 commit 39794ec
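
For context on the mixed-precision fix: under fp16, the LoRA parameters added to the fp16 base UNet can themselves end up in half precision, and they need to be upcast to float32 before the optimizer is built. A minimal sketch of that pattern, assuming the SDXL base checkpoint and LoRA target modules this script trains; the rank, learning rate, and AdamW choice below are illustrative, not taken from the script:

import torch
from diffusers import UNet2DConditionModel
from diffusers.training_utils import cast_training_params
from peft import LoraConfig

# Base model loaded in fp16, as under --mixed_precision="fp16".
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

# Attach a LoRA adapter; only its parameters have requires_grad=True.
unet.add_adapter(LoraConfig(r=8, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

# The adapter weights may inherit the base model's half precision, so upcast
# just the trainable (LoRA) parameters to fp32 before handing them to the optimizer.
cast_training_params(unet, dtype=torch.float32)
optimizer = torch.optim.AdamW((p for p in unet.parameters() if p.requires_grad), lr=1e-4)

Because cast_training_params only touches parameters with requires_grad=True, the frozen fp16 base weights keep their memory savings.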

File tree: 1 file changed (+32, -9 lines)


Diff for: examples/consistency_distillation/train_lcm_distill_lora_sdxl.py (+32, -9)
@@ -36,7 +36,7 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig, get_peft_model_state_dict
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -52,7 +52,12 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available


@@ -858,11 +863,6 @@ def main(args):
     )
     unet.add_adapter(lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(unet, dtype=torch.float32)
-
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
     sigma_schedule = sigma_schedule.to(accelerator.device)
@@ -887,13 +887,31 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         # load the LoRA into the model
         unet_ = accelerator.unwrap_model(unet)
-        lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir)
-        StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
+        unet_state_dict = {
+            f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+        }
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

         for _ in range(len(models)):
             # pop models so that they are not loaded again
             models.pop()

+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            cast_training_params(unet_, dtype=torch.float32)
+
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)

@@ -1092,6 +1110,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
+
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
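
For readers tracing how the rewritten load_model_hook gets called: it only runs through Accelerate's checkpointing path, which the usual --resume_from_checkpoint flow of these examples ends up invoking. A minimal, self-contained sketch of that wiring, with no-op stand-ins for the script's real hooks and an illustrative checkpoint directory name:

import torch
from accelerate import Accelerator

accelerator = Accelerator()

def save_model_hook(models, weights, output_dir):
    # Stand-in: the training script serializes the LoRA weights here.
    pass

def load_model_hook(models, input_dir):
    # Stand-in: the script's rewritten hook (shown in the diff above) loads the
    # LoRA state dict with set_peft_model_state_dict and re-casts it to fp32.
    pass

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

model = accelerator.prepare(torch.nn.Linear(4, 4))

# save_state writes a checkpoint (calling save_model_hook first), and resuming
# boils down to load_state, which calls load_model_hook(models, input_dir)
# before restoring the registered objects.
accelerator.save_state("checkpoint-illustration")
accelerator.load_state("checkpoint-illustration")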
