[Training] fix training resuming problem when using FP16 (SDXL LoRA DreamBooth) #6514

Merged · 11 commits · Jan 12, 2024
74 changes: 51 additions & 23 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -34,7 +34,7 @@
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -53,8 +53,13 @@
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
from diffusers.utils import (
check_min_version,
convert_state_dict_to_diffusers,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

@@ -997,17 +1002,6 @@ def main(args):
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)

# Make sure the trainable params are in float32.

if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
@@ -1064,17 +1058,39 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
Comment on lines -1067 to -1072 (Member Author):
We cannot use load_lora_into_unet() and load_lora_into_text_encoder() here, for the following reason (described for the UNet, but it applies to the text encoders too); see the sketch below.

  • We call add_adapter() once on the unet at the beginning of training. This creates an adapter config inside the UNet.
  • Then, when loading an intermediate checkpoint in the accelerate hook, load_lora_into_unet() would internally call inject_adapter_in_model() again, with a config inferred from the provided state dict. That registers a second adapter instead of populating the existing one, which is undesirable.
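
A minimal sketch of the problem (hypothetical snippet; it assumes the PEFT integration exposes the registered adapter configs on the UNet via `peft_config` after `add_adapter()`):

```python
# Hypothetical illustration of the double-injection issue described above.
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
unet.add_adapter(
    LoraConfig(
        r=4,
        lora_alpha=4,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
)
# Exactly one adapter ("default") is registered on the model at this point.
print(list(unet.peft_config))

# Routing an intermediate checkpoint through `load_lora_into_unet()` now would call
# `inject_adapter_in_model()` a second time and register another adapter, instead of
# filling the weights of the existing "default" one. The hook below therefore strips
# the "unet." prefix, converts the checkpoint keys to the PEFT naming scheme, and
# writes them directly with `set_peft_model_state_dict(..., adapter_name="default")`.
```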

unet_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
if args.train_text_encoder:
# Do we need to call `scale_lora_layers()` here?
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)

_set_state_dict_into_text_encoder(
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
)

# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1089,6 +1105,17 @@ def load_model_hook(models, input_dir):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
Comment on lines +1108 to +1117 (Member Author):
We do this right before passing the parameters to the optimizer, so the optimizer only ever sees the upcast fp32 parameters.

In a follow-up PR, I can wrap this utility into a function and move it to training_utils.py, since it's shared by a number of scripts.
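
For illustration, such a helper could look roughly like this (a sketch only; the name and signature are assumptions, not something this PR adds):

```python
# Hypothetical helper distilled from the upcasting loop above.
from typing import List, Union

import torch


def cast_training_params(
    models: Union[torch.nn.Module, List[torch.nn.Module]], dtype: torch.dtype = torch.float32
):
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(dtype)
```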


unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))

if args.train_text_encoder:
@@ -1506,6 +1533,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
1 change: 0 additions & 1 deletion src/diffusers/loaders/lora.py
@@ -581,7 +581,6 @@ def load_lora_into_text_encoder(
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)

lora_config = LoraConfig(**lora_config_kwargs)

# adapter_name
30 changes: 29 additions & 1 deletion src/diffusers/training_utils.py
@@ -8,12 +8,21 @@
from torchvision import transforms

from .models import UNet2DConditionModel
from .utils import deprecate, is_transformers_available
from .utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
is_peft_available,
is_transformers_available,
)


if is_transformers_available():
import transformers

if is_peft_available():
from peft import set_peft_model_state_dict


def set_seed(seed: int):
"""
@@ -112,6 +121,25 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
return lora_state_dict


def _set_state_dict_into_text_encoder(
lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
"""
Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

Args:
lora_state_dict: The state dictionary to be set.
prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
text_encoder: Where the `lora_state_dict` is to be set.
"""

text_encoder_state_dict = {
k.replace(prefix, ""): v for k, v in lora_state_dict.items() if k.startswith(prefix)
}
text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
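
A usage sketch that mirrors the load hook in the training script (the checkpoint path is a placeholder, and the text encoder must already carry the LoRA adapter added during training):

```python
# Load a checkpoint saved by the training script and write the "text_encoder."-prefixed
# LoRA weights into the text encoder's existing "default" adapter.
from peft import LoraConfig
from transformers import CLIPTextModel

from diffusers.loaders import LoraLoaderMixin
from diffusers.training_utils import _set_state_dict_into_text_encoder

text_encoder_one = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder"
)
text_encoder_one.add_adapter(
    LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
)

lora_state_dict, _ = LoraLoaderMixin.lora_state_dict("checkpoint-500")  # placeholder path
_set_state_dict_into_text_encoder(
    lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one
)
```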


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""