
[feat] add load_lora_adapter() for compatible models #9712

Merged: sayakpaul merged 18 commits into main from lora-load-adapter on Nov 2, 2024

Conversation

@sayakpaul (Member) commented on Oct 18, 2024:

What does this PR do?

Similar to load_attn_procs(), we want an analogous method for loading LoRAs directly into models, since the LoRA loading logic is generic.

This way, we can reduce the lines of code and improve maintainability. I am not too fixated on the load_lora_adapter() name; it could also be load_adapter().

@DN6, as discussed via Slack, could you give this a check? We could also add a save_lora_adapter() method to complement it.
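For context, here is a minimal usage sketch of what the model-level API could look like. The model class and base checkpoint are real, but the LoRA repo id and the exact keyword arguments are placeholders/assumptions, not the final interface:

```python
# Hypothetical usage sketch of the proposed model-level LoRA loading API.
# The LoRA repo id and exact kwargs below are placeholders and may differ
# from the merged implementation.
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

# Load LoRA weights directly into the model, analogous to load_attn_procs(),
# instead of going through a pipeline-level loader.
transformer.load_lora_adapter(
    "some-user/some-flux-lora",                      # placeholder repo id
    weight_name="pytorch_lora_weights.safetensors",  # assumed kwarg
    adapter_name="my_lora",
)
```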

@sayakpaul sayakpaul requested a review from DN6 October 18, 2024 15:44
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -181,6 +181,123 @@ def _remove_text_encoder_monkey_patch(text_encoder):
text_encoder._hf_peft_config_loaded = None


def _fetch_state_dict(
@sayakpaul (Member Author):

Just taking it out of the class to make it easier to reuse.
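As a rough illustration of the reuse this enables (the class names and signatures below are simplified and hypothetical, not the actual diff):

```python
# Simplified sketch: a module-level helper can be shared by both the
# pipeline-level LoRA loaders and the new model-level load_lora_adapter(),
# instead of living as a classmethod on one loader mixin.
def _fetch_state_dict(pretrained_model_name_or_path_or_dict, weight_name=None, **kwargs):
    """Resolve a local path, Hub repo id, or in-memory dict into a flat state dict."""
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        return pretrained_model_name_or_path_or_dict
    # ... download / load the checkpoint (safetensors or .bin) ...
    raise NotImplementedError("illustrative stub")


class SomePipelineLoraLoaderMixin:       # hypothetical pipeline-level loader
    def load_lora_weights(self, pretrained, **kwargs):
        state_dict = _fetch_state_dict(pretrained, **kwargs)
        # ... split by prefix and dispatch to the individual models ...


class SomeModelAdapterMixin:             # hypothetical model-level loader
    def load_lora_adapter(self, pretrained, **kwargs):
        state_dict = _fetch_state_dict(pretrained, **kwargs)
        # ... hand the filtered keys to PEFT ...
```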

@DN6 (Collaborator) left a comment:

Nice clean up 👍🏽

@@ -233,126 +350,6 @@ def _optionally_disable_offloading(cls, _pipeline):

return (is_model_cpu_offload, is_sequential_cpu_offload)

@classmethod
def _fetch_state_dict(
@DN6 (Collaborator):

These are internal methods, so it should be okay to move them around. But it would be good to run a quick GitHub search to check whether they are being used directly somewhere, just to sanity-check that we don't break backwards compatibility.

@sayakpaul (Member Author):

Valid. I deprecated them and added tests.
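For reference, a sketch of what such a deprecation shim typically looks like in diffusers, using the library's `deprecate()` utility; the class name, import path, message, and removal version below are illustrative assumptions:

```python
from diffusers.utils import deprecate
# module-level helper introduced in this PR; import path is illustrative
from diffusers.loaders.lora_base import _fetch_state_dict


class LoraBaseMixin:  # illustrative; the real mixin has many more methods
    @classmethod
    def _fetch_state_dict(cls, *args, **kwargs):
        # Keep the old classmethod as a thin wrapper so any external callers
        # don't break immediately: warn, then delegate to the new
        # module-level helper.
        deprecation_message = (
            "Using `_fetch_state_dict()` from a loader class is deprecated; "
            "use the module-level `_fetch_state_dict()` instead."
        )
        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)  # version is a placeholder
        return _fetch_state_dict(*args, **kwargs)
```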

@sayakpaul sayakpaul requested a review from DN6 October 24, 2024 15:26
@sayakpaul (Member Author) commented:

@DN6 LMK what you think of the latest changes.

Additionally, what do you think about the save_lora_adapter() method? Can do in another PR, LMK.
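For discussion, a rough sketch of what a complementary `save_lora_adapter()` could do with PEFT and safetensors utilities; this is an assumption about a possible implementation, not the merged code:

```python
import os

import safetensors.torch
from peft.utils import get_peft_model_state_dict


def save_lora_adapter(model, save_directory, adapter_name="default",
                      weight_name="pytorch_lora_weights.safetensors"):
    # Collect only the LoRA parameters registered under `adapter_name`
    # and write them out as a safetensors checkpoint that
    # load_lora_adapter() could later consume.
    lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
    os.makedirs(save_directory, exist_ok=True)
    safetensors.torch.save_file(
        lora_state_dict, os.path.join(save_directory, weight_name)
    )
```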

@sayakpaul (Member Author) commented:

@BenjaminBossan could you give this a look too?

@BenjaminBossan (Member) left a comment:

Thanks for this refactor, always happy to see more lines being removed than added. I didn't check the functions that were moved around, as I assume they were left identical. Regarding the rest, just some smaller comments.

_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
@BenjaminBossan (Member):

For my understanding: This is a fix independent of the main change of this PR, right? Would it be possible to move this check inside of load_lora_into_transformer or would that not be a good idea?

@sayakpaul (Member Author):

I think it kind of depends on the entrypoint to the underlying method.

We already have a similar check within load_lora_adapter():

transformer_keys = [k for k in keys if k.startswith(prefix)]

So, I think it should be fine without it.
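To make the two checks concrete, here is a small sketch of the filtering being discussed; the key names and values are illustrative placeholders:

```python
# Illustrative example of the two filtering styles discussed above.
state_dict = {
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": "…",
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": "…",
}

# Pipeline-level filter (as in the diff above): keep only transformer keys
# before handing them to the transformer loader.
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}

# Model-level filter inside load_lora_adapter(): keys matching the model's
# prefix are selected (and the prefix later stripped) before going to PEFT,
# so unrelated keys are dropped either way.
prefix = "transformer."
transformer_keys = [k for k in state_dict if k.startswith(prefix)]
```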

@@ -341,7 +339,9 @@ def load_lora_into_text_encoder(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
low_cpu_mem_usage (`bool`, *optional*):
@BenjaminBossan (Member):

I saw that too, thx for fixing.

sayakpaul and others added 2 commits November 1, 2024 10:41
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@BenjaminBossan (Member) left a comment:

Thanks for addressing my comments, LGTM.

@sayakpaul (Member Author) commented:

Ran the Flux integration tests and they pass. Failing tests are unrelated.

@sayakpaul sayakpaul merged commit 13e8fde into main Nov 2, 2024
15 of 18 checks passed
@sayakpaul sayakpaul deleted the lora-load-adapter branch November 2, 2024 04:20
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* add first draft.

* fix

* updates.

* updates.

* updates

* updates

* updates.

* fix-copies

* lora constants.

* add tests

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* docstrings.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>