
make gradient checkpointing with frozen model possible #9850

Closed
MikeTkachuk opened this issue Nov 3, 2024 · 10 comments
Labels
bug Something isn't working

Comments

@MikeTkachuk
Contributor

MikeTkachuk commented Nov 3, 2024

Describe the bug

if self.training and self.gradient_checkpointing:

Hi, the clause I highlighted in the link above prevents a model from using gradient checkpointing in eval mode. This is particularly useful for e.g. LoRAs.
Perhaps you meant to check something like this instead?
if torch.is_grad_enabled() and self.gradient_checkpointing:

The same applies to any other module in the repo that uses this pattern.
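For illustration, a minimal sketch (a toy module, not diffusers code) of how the proposed condition would gate checkpointing on autograd being enabled rather than on the training flag:

import torch
from torch.utils.checkpoint import checkpoint

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False
        self.proj = torch.nn.Linear(32, 32)

    def forward(self, hidden_states):
        # Proposed condition: active in both train() and eval(), as long as autograd is on
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return checkpoint(self.proj, hidden_states, use_reentrant=False)
        return self.proj(hidden_states)

block = TinyBlock()
block.gradient_checkpointing = True
block.eval()  # frozen base model, e.g. one carrying LoRA adapters
x = torch.randn(1, 32, requires_grad=True)
block(x).sum().backward()  # gradients flow through the checkpointed call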

Reproduction

import torch
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn

block = UNetMidBlock2DCrossAttn(32, 32, 32, cross_attention_dim=32)
block.gradient_checkpointing = True
block.eval()

block(torch.randn((1, 32, 64, 64)), torch.randn((1, 32,)))

Logs

No response

System Info

  • 🤗 Diffusers version: 0.30.3
  • Platform: Windows-10-10.0.22631-SP0
  • Running on Google Colab?: No
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.4.1+cu118 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.25.1
  • Transformers version: 4.45.1
  • Accelerate version: 0.34.2
  • PEFT version: 0.12.0
  • Bitsandbytes version: 0.44.1
  • Safetensors version: 0.4.5
  • xFormers version: 0.0.28.post3
  • Accelerator: NVIDIA GeForce RTX 4060 Laptop GPU, 8188 MiB
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

Other: @yiyixuxu @DN6

@MikeTkachuk MikeTkachuk added the bug Something isn't working label Nov 3, 2024
@SahilCarterr
Contributor

The issue raised here highlights a specific use case: enabling gradient checkpointing during evaluation mode (i.e., eval() mode) when gradients are still needed. This scenario could apply in cases like LoRA (Low-Rank Adaptation) or other fine-tuning strategies where we might still compute gradients on a pre-trained model while in evaluation mode.

Reproduction

import torch
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn

# Subclass UNetMidBlock2DCrossAttn to inspect the library's current check
class ModifiedUNetMidBlock2DCrossAttn(UNetMidBlock2DCrossAttn):
    def forward(self, *args, **kwargs):
        # Reproduce the current condition: checkpointing is gated on self.training,
        # so it is skipped in eval() mode even when gradients are enabled
        if self.training and self.gradient_checkpointing:
            print("Checkpointing is active")
        else:
            print("Checkpointing is not active")
        return super().forward(*args, **kwargs)

# Instantiate the modified model
block = ModifiedUNetMidBlock2DCrossAttn(32, 32, 32, cross_attention_dim=32)
block.gradient_checkpointing = True  # Enable gradient checkpointing
block.eval() 

# Run the forward pass with gradients enabled to test checkpointing in eval mode
with torch.set_grad_enabled(True):  # Ensure gradients are enabled
    input_tensor = torch.randn((1, 32, 64, 64))  # Dummy input tensor
    conditioning_tensor = torch.randn((1, 32))    # Dummy conditioning tensor
    output = block(input_tensor, conditioning_tensor)

Output

Checkpointing is not active

Fix

Using torch.is_grad_enabled() instead:

if torch.is_grad_enabled() and self.gradient_checkpointing:

This approach enables gradient checkpointing as long as gradients are enabled, regardless of whether the model is in train() or eval() mode.
It allows flexibility for tasks like LoRA or other fine-tuning techniques where gradients are needed in eval() mode.
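As a quick illustration (a plain nn.Linear stand-in, not diffusers code) of how the two conditions diverge in eval() mode:

import torch

m = torch.nn.Linear(4, 4)
m.eval()
with torch.set_grad_enabled(True):
    print(m.training)               # False -> current check skips checkpointing
    print(torch.is_grad_enabled())  # True  -> proposed check would keep it on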
@MikeTkachuk

@sayakpaul
Member

@MikeTkachuk could you elaborate on a use case where this could be beneficial?

@SahilCarterr your response seems very much LLM-generated :D

@MikeTkachuk
Contributor Author

  1. Whenever I need to compute gradients and keep the model frozen. With eval() I prevent the model's running stats from drifting. There are a ton of use cases: gradient inspection, teacher models, input optimization (see the sketch below), etc. Torch autograd has never had anything to do with the 'training=True' flag, so in my opinion gradient checkpointing shouldn't either.
  2. I stumbled upon this while trying to use https://github.com/rohitgandikota/sliders, which in my case uses LoRA on top of SDXL.
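A hypothetical sketch of one such use case, input optimization against a frozen model kept in eval(), with a plain torch module standing in for a diffusion block:

import torch

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
model.eval()
for p in model.parameters():
    p.requires_grad_(False)  # the model stays frozen

x = torch.randn(1, 32, requires_grad=True)  # only the input is optimized
opt = torch.optim.Adam([x], lr=1e-2)
for _ in range(10):
    opt.zero_grad()
    loss = model(x).pow(2).mean()
    loss.backward()  # gradients reach x even though the model is frozen and in eval()
    opt.step()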

@sayakpaul
Member

Whenever I need to compute gradients and keep the model frozen.

But you're still computing the gradients for the trainable params that might have been added to the model, LoRA, for example. I am struggling to understand how the current implementation doesn't let you do that.

With eval() I prevent model running stats from drifting.

This is a valid use case for models that involve layers with different train/inference behaviours. Batch normalization, for example. But diffusion models typically don't use them. So, maybe you could expand a bit further?

@MikeTkachuk
Contributor Author

The current implementation allows gradient checkpointing only when the module is in the training state, as if deliberately preventing me from computing gradients in the eval state. Is there any reason for that? Maybe diffusers uses the training flag in some unconventional way? IMHO it is a serious departure from torch design.

Also, I'm pretty sure the Dropout layer behaves differently in eval state.

@sayakpaul
Member

Yes, you're right that dropout has a different inference behaviour (inverse dropout), but it's rarely used in the library.

Cc @yiyixuxu do you have any suggestions here?

@MikeTkachuk
Contributor Author

The thing is that I'm still able to compute gradients in the eval state, just not with gradient checkpointing, which is odd. I wonder what the reason behind this design decision was.

@yiyixuxu
Collaborator

yiyixuxu commented Nov 5, 2024

It is indeed a small use case, but I think the proposed change makes a lot of sense! Thanks a lot @MikeTkachuk!

I don't think we would break anything here either: if the model is in training mode, gradient checkpointing is still on. The only difference is that when the model is in eval mode, as long as gradient computation is enabled and the user has explicitly turned on the gradient_checkpointing flag, they will now be able to use gradient checkpointing.

cc @DN6 here for his opinion too

@MikeTkachuk
Contributor Author

MikeTkachuk commented Nov 6, 2024

Thanks for your reply. I would like to view it as an artificial and likely unnecessary limitation rather than a use case.
By the way, checkpointing works even with autograd turned off, so you might want to remove that check altogether and leave it as:
if self.gradient_checkpointing:

Apparently, this is also possibly degrading performance in the kohya-ss repo:
https://github.com/kohya-ss/sd-scripts/blob/ca44e3e447fc1185ce188229b4e1a0f7f3bbbf66/train_network.py#L447
i.e. it switches the unet to train() mode only when gradient checkpointing is enabled; the regular flow keeps it in eval().
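A small illustration of the point about autograd being off (a plain linear layer as a stand-in): torch.utils.checkpoint simply runs the forward pass when no graph is being built.

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
with torch.no_grad():
    out = checkpoint(layer, torch.randn(2, 8), use_reentrant=False)
print(out.requires_grad)  # False: no graph was built, but the forward ran fine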

@DN6
Collaborator

DN6 commented Nov 6, 2024

Yeah, the change looks fine to me. I don't think it will break anything. @MikeTkachuk Would you be able to open a PR for it?
