
Remove instance-level forward on unpatch #2196

Open
wants to merge 1 commit into main

Conversation


@AlexanderDokuchaev commented Feb 21, 2025

What does this PR do?

Prevents an unexpected instance-level override of the forward method from remaining on the model after ModelPatcher restores the original method.

assert "forward" not in model.__dict__
patcher = ModelPatcher(export_config, model)
with patcher:
    pass
assert "forward" not in model.__dict__  # AssertionError

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@AlexanderDokuchaev (Author) commented

Example when "forward" in model.__dict__ raise exception

import torch
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-18").cuda()
inputs = torch.rand(10, 3, 224, 224).cuda()

# Set an instance-level forward like
# https://github.com/huggingface/optimum/blob/v1.24.0-release/optimum/exporters/onnx/model_patcher.py#L292
setattr(model, "forward", model.forward)

data_parallel = torch.nn.DataParallel(model)
data_parallel(inputs)
# RuntimeError: Expected all tensors to be on the same device,
# but found at least two devices, cuda:1 and cuda:0!
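This fails because DataParallel replicates the module onto each GPU, but the instance-level forward stored in model.__dict__ is a bound method tied to the original module, whose parameters live on cuda:0. Replication shallow-copies instance attributes, so every replica inherits that same bound method, and inputs scattered to cuda:1 are still routed through the cuda:0 weights, producing the device-mismatch error above. Removing the instance-level forward on unpatch avoids leaving the model in this state.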
