make gradient checkpointing with frozen model possible #9850
Comments
The issue raised here highlights a specific use case: enabling gradient checkpointing while the model is in evaluation (eval()) mode but gradients are still needed. This scenario applies to LoRA (Low-Rank Adaptation) and other fine-tuning strategies where gradients are still computed through a pre-trained model that is kept in evaluation mode.

Reproduction

```python
import torch
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn

# Define a modified UNetMidBlock2DCrossAttn subclass that reports whether
# checkpointing would be triggered under the current implementation's check
class ModifiedUNetMidBlock2DCrossAttn(UNetMidBlock2DCrossAttn):
    def forward(self, *args, **kwargs):
        # Mirror the current implementation, which gates checkpointing on self.training
        if self.training and self.gradient_checkpointing:
            print("Checkpointing is active")
        else:
            print("Checkpointing is not active")
        return super().forward(*args, **kwargs)

# Instantiate the modified model
block = ModifiedUNetMidBlock2DCrossAttn(32, 32, 32, cross_attention_dim=32)
block.gradient_checkpointing = True  # Enable gradient checkpointing
block.eval()

# Run the forward pass with gradients enabled to test checkpointing in eval mode
with torch.set_grad_enabled(True):  # Ensure gradients are enabled
    input_tensor = torch.randn((1, 32, 64, 64))  # Dummy input tensor
    conditioning_tensor = torch.randn((1, 32))  # Dummy conditioning tensor
    output = block(input_tensor, conditioning_tensor)
```

Output

```
Checkpointing is not active
```
Fix

Use torch.is_grad_enabled() instead:

```python
if torch.is_grad_enabled() and self.gradient_checkpointing:
```

This enables gradient checkpointing as long as gradients are enabled, regardless of whether the model is in train() or eval() mode.
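Below is a minimal, self-contained sketch of how such a grad-enabled gate could look inside a checkpointed block. The CheckpointedBlock module and its layers are illustrative stand-ins, not diffusers' actual implementation.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Illustrative block; `layers` stands in for a block's resnets/attentions."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Proposed gate: depends on whether grad is enabled, not on self.training,
        # so LoRA-style fine-tuning with the base model in eval() still benefits.
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return checkpoint(self.layers, hidden_states, use_reentrant=False)
        return self.layers(hidden_states)


block = CheckpointedBlock()
block.gradient_checkpointing = True
block.eval()  # base model kept in eval mode

x = torch.randn(1, 32, requires_grad=True)
block(x).sum().backward()  # checkpointing is used because grads are enabled
print(x.grad.shape)  # torch.Size([1, 32])
```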
@MikeTkachuk could you elaborate on a use case where this could be beneficial? @SahilCarterr your response seems very much LLM-generated :D
But you're still computing the gradients for the trainable params that might have been added to the model (LoRA, for example). I am struggling to understand how the current implementation prevents you from doing that.
This is a valid use case for models that involve layers with different train/inference behaviours (batch normalization, for example), but diffusion models typically don't use them. So maybe you could expand a bit further?
The current implementation allows gradient checkpointing only when the module is in the training state, as if deliberately preventing me from computing gradients in the eval state. Is there any reason for that? Maybe diffusers utilizes the training flag in some unconventional way? IMHO it is a serious departure from torch's design. Also, I'm pretty sure the Dropout layer behaves differently in the eval state.
Yes, you're right that dropout has a different inference behaviour (inverted dropout), but it's rarely used in the library. Cc @yiyixuxu, do you have any suggestions here?
The thing is that I'm still able to compute gradients in the eval state, just not with gradient checkpointing, which is odd. I wonder what the reason behind this design decision was?
It is indeed a small use case, but I think the proposed change makes a lot of sense! Thanks a lot @MikeTkachuk! I don't think we would break anything here either: if the model is in training mode, gradient checkpointing is still on; the only difference is that when the model is in eval mode, as long as gradient computation is enabled and the user has explicitly turned on the gradient_checkpointing flag, they will now be able to use gradient checkpointing. cc @DN6 here for his opinion too
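To spell out the behaviour described above, here is a small sketch of the resulting decision rule (a paraphrase of the proposal, not the exact diffusers code):

```python
import torch


def should_checkpoint(gradient_checkpointing: bool) -> bool:
    # Checkpointing applies whenever gradients are being computed and the user
    # opted in via the gradient_checkpointing flag, independent of train()/eval().
    return torch.is_grad_enabled() and gradient_checkpointing


# train() with grads enabled                -> checkpointing on (same as today)
# eval() with grads enabled                 -> checkpointing on (new behaviour)
# torch.no_grad() / torch.inference_mode()  -> checkpointing off in either mode
```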
Thanks for your reply. I would like to view it as an artificial and likely unnecessary limitation rather than a use case. Apparently, it may also be degrading performance in the kohya-ss repo because of this.
Yeah, the change looks fine to me. I don't think it will break anything. @MikeTkachuk would you be able to open a PR for it?
Describe the bug
diffusers/src/diffusers/models/unets/unet_2d_blocks.py, line 862 at commit 89e4d62
Hi, the clause I highlighted in the link above prevents a model from using gradient checkpointing in eval mode. Checkpointing in eval mode is particularly useful for, e.g., LoRAs.
Perhaps you meant to check something like this instead?
if torch.is_grad_enabled() and self.gradient_checkpointing:
The same applies to any other module in the repo.
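For context, here is a quick standalone illustration (a hypothetical LoRA-style adapter on a plain Linear layer, not diffusers code) showing that gradients are routinely needed while the base model sits in eval(), which is exactly the situation the current self.training gate excludes from checkpointing:

```python
import torch
import torch.nn as nn

# Frozen base layer, kept in eval mode as during LoRA fine-tuning
base = nn.Linear(32, 32)
for p in base.parameters():
    p.requires_grad_(False)
base.eval()

# Minimal LoRA-style adapter: a trainable low-rank update on top of the frozen layer
lora_down = nn.Linear(32, 4, bias=False)
lora_up = nn.Linear(4, 32, bias=False)

x = torch.randn(1, 32)
out = base(x) + lora_up(lora_down(x))
out.sum().backward()

print(base.weight.grad)            # None: the base layer stays frozen
print(lora_up.weight.grad.shape)   # torch.Size([32, 4]): the adapter still trains
```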
Reproduction
Logs
No response
System Info
Who can help?
Other: @yiyixuxu @DN6