[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) #6514
@@ -34,7 +34,7 @@
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from packaging import version
-from peft import LoraConfig
+from peft import LoraConfig, set_peft_model_state_dict
 from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
@@ -53,8 +53,13 @@
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -997,17 +1002,6 @@ def main(args):
         text_encoder_one.add_adapter(text_lora_config)
         text_encoder_two.add_adapter(text_lora_config)
 
-    # Make sure the trainable params are in float32.
-    if args.mixed_precision == "fp16":
-        models = [unet]
-        if args.train_text_encoder:
-            models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
-
-        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
-        )

Comment on lines -1067 to -1072: We cannot be using
+        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
-        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
-        LoraLoaderMixin.load_lora_into_text_encoder(
-            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
-        )
+        if args.train_text_encoder:
+            # Do we need to call `scale_lora_layers()` here?
+            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+            _set_state_dict_into_text_encoder(
+                lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+            )
+
+        # Make sure the trainable params are in float32. This is again needed since the base models
+        # are in `weight_dtype`. More details:
+        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+        if args.mixed_precision == "fp16":
+            models = [unet_]
+            if args.train_text_encoder:
+                models.extend([text_encoder_one_, text_encoder_two_])
+            for model in models:
+                for param in model.parameters():
+                    # only upcast trainable parameters (LoRA) into fp32
+                    if param.requires_grad:
+                        param.data = param.to(torch.float32)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
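For context on why the upcast is repeated inside `load_model_hook`: after resuming, the base weights are still in the mixed-precision `weight_dtype` (fp16), and the freshly loaded LoRA parameters must be back in float32 before the optimizer and GradScaler touch them, otherwise unscaling fp16 gradients fails. Below is a small sanity-check sketch, not part of the PR; the helper name and the commented usage of the script's variables are assumptions.

import torch

def assert_trainable_params_fp32(model):
    # Not from the PR: verify that every trainable (LoRA) parameter was
    # upcast to float32; frozen base weights may stay in fp16/bf16.
    for name, param in model.named_parameters():
        if param.requires_grad and param.dtype != torch.float32:
            raise RuntimeError(f"trainable param {name} has dtype {param.dtype}, expected torch.float32")

# Example usage (variable names taken from the hook above):
# assert_trainable_params_fp32(unet_)
# if args.train_text_encoder:
#     assert_trainable_params_fp32(text_encoder_one_)
#     assert_trainable_params_fp32(text_encoder_two_)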
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
Comment on lines +1108 to +1117: We do it just before assigning the parameters to the optimizer to avoid any consequences.

In a follow-up PR, I can wrap this utility into a function and move to
 
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 
     if args.train_text_encoder:
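The review thread above suggests wrapping this repeated upcast loop into a reusable utility. A possible shape for such a helper is sketched below; the name `cast_training_params` and its eventual home in `diffusers.training_utils` are assumptions, not something this diff introduces.

import torch

def cast_training_params(models, dtype=torch.float32):
    # Hypothetical helper: upcast only the trainable (LoRA) parameters,
    # leaving frozen base weights in their mixed-precision dtype.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)

# The two duplicated blocks in this diff could then collapse to, e.g.:
# if args.mixed_precision == "fp16":
#     models = [unet, text_encoder_one, text_encoder_two] if args.train_text_encoder else [unet]
#     cast_training_params(models, dtype=torch.float32)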
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
https://github.com/huggingface/diffusers/pull/6514/files#r1447020705
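To tie the pieces together: the hooks registered above only run when a checkpoint is actually restored, which is the resume path this PR fixes for fp16 runs. Below is a rough sketch of that path; `args` and `accelerator` come from the training script, and the checkpoint folder name is only an example.

import os

# Sketch only: an example checkpoint directory written earlier by accelerator.save_state().
resume_path = os.path.join(args.output_dir, "checkpoint-500")

# accelerator.load_state() invokes the registered load_model_hook, which now
# restores the LoRA weights via peft and re-upcasts the trainable parameters
# to float32 when training with --mixed_precision="fp16".
accelerator.load_state(resume_path)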