train_dreambooth_lora_sdxl_advanced.py --resume_from_checkpoint fails with ValueError: Attempting to unscale FP16 gradients. #6482

Closed
steverhoades opened this issue Jan 8, 2024 · 3 comments · Fixed by #6566
Labels: bug (Something isn't working)

@steverhoades (Contributor)

Describe the bug

After a system crash, I attempted to resume training from a prior checkpoint.

Expected Result:
Training resumes from the last checkpoint and finishes successfully.

Actual Result:

 File "/workspace/train_dreambooth_lora_sdxl_advanced.py", line 2104, in <module>
    main(args)
  File "/workspace/train_dreambooth_lora_sdxl_advanced.py", line 1861, in main
    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2040, in clip_grad_norm_
    self.unscale_gradients()
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2003, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/amp/grad_scaler.py", line 307, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/amp/grad_scaler.py", line 229, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.
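
For context, GradScaler.unscale_() raises this ValueError by design whenever a trainable parameter's gradient is fp16: under --mixed_precision="fp16" the frozen weights are meant to be fp16 while the trainable LoRA parameters stay fp32, and resuming from a checkpoint can load the LoRA parameters back in fp16. A minimal sketch of the usual workaround after the checkpoint is restored, assuming the unet / text_encoder_one / text_encoder_two names used in the script and a diffusers version that ships training_utils.cast_training_params:

```python
import torch
from diffusers.training_utils import cast_training_params

# After accelerator.load_state(...) restores checkpoint-1600, the LoRA
# parameters may come back as fp16. GradScaler.unscale_() refuses fp16
# gradients, so upcast only the trainable parameters back to fp32.
models = [unet]  # assumed: `unet` as prepared in main()
if args.train_text_encoder:  # assumed: same flag/variables as the script
    models.extend([text_encoder_one, text_encoder_two])

cast_training_params(models, dtype=torch.float32)  # frozen weights stay fp16
```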

Reproduction

!accelerate launch train_dreambooth_lora_sdxl_advanced.py \
  --report_to="wandb" \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="./training_images" \
  --output_dir="father_lora_v9" \
  --cache_dir="./dataset_cache_dir" \
  --caption_column="prompt" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of Brian de palma man" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1680 \
  --checkpointing_steps=200 \
  --validation_prompt="a photo of Brian de palma in a suit" \
  --validation_epochs=10 \
  --train_text_encoder \
  --with_prior_preservation \
  --class_data_dir="./prior_preservation" \
  --num_class_images=150 \
  --class_prompt="a photo of an old man" \
  --rank=32 \
  --optimizer="prodigy" \
  --prodigy_safeguard_warmup=True \
  --prodigy_use_bias_correction=True \
  --adam_beta1=0.9 \
  --adam_beta2=0.99 \
  --adam_weight_decay=0.01 \
  --learning_rate=1 \
  --text_encoder_lr=1 \
  --resume_from_checkpoint="checkpoint-1600" \
  --seed="0"
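
For reference, the resume path in these diffusers training scripts typically looks like the sketch below (paraphrased, not copied verbatim from train_dreambooth_lora_sdxl_advanced.py); accelerator.load_state() is the call that restores the checkpoint state which later trips the scaler:

```python
import os

# Paraphrased sketch of the standard diffusers resume logic.
if args.resume_from_checkpoint != "latest":
    path = os.path.basename(args.resume_from_checkpoint)  # e.g. "checkpoint-1600"
else:
    # pick the most recent checkpoint-* directory in output_dir
    dirs = [d for d in os.listdir(args.output_dir) if d.startswith("checkpoint")]
    path = sorted(dirs, key=lambda d: int(d.split("-")[1]))[-1] if dirs else None

accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])  # resume the step counter at 1600
```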

Logs

Steps:  95%|████████████████████ | 1600/1680 [00:05<?, ?it/s, loss=0.0168, lr=1]
Traceback (most recent call last):
  File "/workspace/train_dreambooth_lora_sdxl_advanced.py", line 2104, in <module>
    main(args)
  File "/workspace/train_dreambooth_lora_sdxl_advanced.py", line 1861, in main
    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2040, in clip_grad_norm_
    self.unscale_gradients()
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 2003, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/amp/grad_scaler.py", line 307, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/amp/grad_scaler.py", line 229, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

System Info

  • diffusers version: 0.26.0.dev0

  • Platform: Linux-5.4.0-156-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Huggingface_hub version: 0.20.2
  • Transformers version: 4.36.2
  • Accelerate version: 0.25.0
  • xFormers version: 0.0.23.post1
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

No response

@steverhoades added the bug (Something isn't working) label on Jan 8, 2024
@sayakpaul (Member)

Cc: @linoytsaban @apolinario

@linoytsaban self-assigned this on Jan 9, 2024
@sayakpaul (Member)

This needs to be solved first: #6510.

@steverhoades (Contributor, Author) commented on Jan 10, 2024

This might be related, but when passing --enable_xformers_memory_efficient_attention for a fresh run (not resuming from a checkpoint), the following error is thrown.

  ...
  File "/home/steve/.local/lib/python3.10/site-packages/xformers/ops/fmha/common.py", line 121, in validate_inputs
    raise ValueError(
ValueError: Query/Key/Value should either all have the same dtype, or (in the quantized case) Key/Value should have dtype torch.int32
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16
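
This looks like the same fp32/fp16 split: part of the attention projection (possibly the LoRA-patched to_q) produces an fp32 query while key and value stay fp16, and xformers requires all three tensors to share a dtype. A tiny illustration of that constraint, assuming xformers is installed and a CUDA device is available (shapes are arbitrary placeholders):

```python
import torch
import xformers.ops as xops

# [batch, seq_len, heads, head_dim]; values are placeholders
q = torch.randn(1, 77, 8, 64, device="cuda", dtype=torch.float32)
k = torch.randn(1, 77, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 77, 8, 64, device="cuda", dtype=torch.float16)

# xops.memory_efficient_attention(q, k, v)  # raises the ValueError above
out = xops.memory_efficient_attention(q.to(k.dtype), k, v)  # dtypes aligned: ok
```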

params

#!/usr/bin/env bash
# adamW optimizer
accelerate launch scripts/train_dreambooth_lora_sdxl_advanced.py \
  --report_to="wandb" \
  --enable_xformers_memory_efficient_attention \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="./training_images" \
  --output_dir="father_lora_v18" \
  --cache_dir="./dataset_cache_dir" \
  --caption_column="prompt" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of ohwx man" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=3600 \
  --checkpointing_steps=100 \
  --validation_prompt="a photo of ohwx man in a suit, looking directly at the camera" \
  --validation_epochs=1 \
  --with_prior_preservation \
  --class_data_dir="./prior_preservation-man-v1" \
  --num_class_images=100 \
  --class_prompt="a photo of a man" \
  --rank=16 \
  --train_text_encoder \
  --optimizer="adamW" \
  --learning_rate=1e-4 \
  --text_encoder_lr=3e-4 \
  --seed="0"
