@@ -36,7 +36,7 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from peft import LoraConfig, get_peft_model_state_dict
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -52,7 +52,12 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available


@@ -858,11 +863,6 @@ def main(args):
     )
     unet.add_adapter(lora_config)

-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(unet, dtype=torch.float32)
-
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
     sigma_schedule = sigma_schedule.to(accelerator.device)
@@ -887,13 +887,31 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         # load the LoRA into the model
         unet_ = accelerator.unwrap_model(unet)
-        lora_state_dict, network_alphas = StableDiffusionXLPipeline.lora_state_dict(input_dir)
-        StableDiffusionXLPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+        lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
+        unet_state_dict = {
+            f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+        }
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )

         for _ in range(len(models)):
             # pop models so that they are not loaded again
             models.pop()

+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            cast_training_params(unet_, dtype=torch.float32)
+
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)

@@ -1092,6 +1110,11 @@ def compute_time_ids(original_size, crops_coords_top_left):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
+
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
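
For readers resuming training manually rather than through accelerate's load hook, the new load_model_hook logic above boils down to: read the serialized LoRA weights, convert their keys to the PEFT naming scheme, load them into the adapter, and upcast the trainable parameters back to float32. Below is a minimal, non-authoritative sketch of that flow using only the calls imported in this diff; the checkpoint directory name "checkpoint-500" and the helper name resume_unet_lora are hypothetical placeholders, and unet is assumed to already have the LoRA adapter attached via unet.add_adapter(lora_config).

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.training_utils import cast_training_params
from diffusers.utils import convert_unet_state_dict_to_peft
from peft import set_peft_model_state_dict


def resume_unet_lora(unet, checkpoint_dir="checkpoint-500", mixed_precision="fp16"):
    # Read the LoRA weights in the diffusers serialization format (keys prefixed with "unet.").
    lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(checkpoint_dir)
    # Keep only the UNet entries and strip the "unet." prefix.
    unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
    # Rename the keys from the diffusers format to the PEFT format expected by the adapter.
    unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
    # Load the converted weights into the "default" adapter of the PEFT-wrapped UNet.
    incompatible = set_peft_model_state_dict(unet, unet_state_dict, adapter_name="default")
    if incompatible is not None and getattr(incompatible, "unexpected_keys", None):
        print(f"Unexpected keys while loading LoRA weights: {incompatible.unexpected_keys}")
    # Under fp16 mixed precision the base weights are half precision, so the trainable
    # LoRA parameters have to be upcast back to float32 before the optimizer touches them.
    if mixed_precision == "fp16":
        cast_training_params(unet, dtype=torch.float32)
    return unet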