[feat] add load_lora_adapter() for compatible models #9712

Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -181,6 +181,123 @@ def _remove_text_encoder_monkey_patch(text_encoder):
    text_encoder._hf_peft_config_loaded = None


def _fetch_state_dict(
Just taking it out of the class to be able to reuse it better.
Nice clean up 👍🏽
@@ -233,126 +350,6 @@ def _optionally_disable_offloading(cls, _pipeline):

        return (is_model_cpu_offload, is_sequential_cpu_offload)

    @classmethod
    def _fetch_state_dict(
These are internal methods, so it should be okay to move them around. But it would be good to run a quick GitHub search to see whether they are being used directly somewhere, just to sanity-check that we don't break anything backwards.
Valid. I deprecated them and added tests.
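For illustration, a minimal sketch of what such a deprecation shim could look like (the class name, warning text, and use of a plain `FutureWarning` are assumptions to keep the sketch self-contained; diffusers has its own `deprecate` utility, and the real helper does the actual hub/file loading):

```python
import warnings


def _fetch_state_dict(pretrained_model_name_or_path_or_dict, weight_name=None, **kwargs):
    """Module-level helper the loader mixins can share (hub/file loading elided in this sketch)."""
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        return pretrained_model_name_or_path_or_dict
    raise NotImplementedError("checkpoint downloading/parsing is elided in this sketch")


class LoraLoaderMixinSketch:
    @classmethod
    def _fetch_state_dict(cls, *args, **kwargs):
        # Kept only for backwards compatibility: warn, then forward to the module-level helper.
        warnings.warn(
            "Calling `_fetch_state_dict` as a class method is deprecated; "
            "use the module-level `_fetch_state_dict` instead.",
            FutureWarning,
        )
        return _fetch_state_dict(*args, **kwargs)
```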
@DN6 LMK what you think of the latest changes. Additionally, what do you think about the

@BenjaminBossan could you give this a look too?
Thanks for this refactor, always happy to see more lines being removed than added. I didn't check the functions that were moved around, as I assume they were left identical. Regarding the rest, just some smaller comments.
    _pipeline=self,
    low_cpu_mem_usage=low_cpu_mem_usage,
)
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
For my understanding: this is a fix independent of the main change of this PR, right? Would it be possible to move this check inside of `load_lora_into_transformer`, or would that not be a good idea?
I think it kind of depends on the entrypoint to the underlying method. We already have a similar check within `load_lora_adapter()`:

diffusers/src/diffusers/loaders/peft.py (line 153 in e187b70):

transformer_keys = [k for k in keys if k.startswith(prefix)]

So I think it should be fine without it.
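As a rough illustration of the prefix-based filtering being discussed (the keys below are made up; the real handling lives in the pipeline loader and in `load_lora_adapter()`):

```python
# Toy LoRA state dict containing keys for two different sub-models.
state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": 0,
    "transformer.blocks.0.attn.to_q.lora_B.weight": 1,
    "text_encoder.layers.0.q_proj.lora_A.weight": 2,
}
prefix = "transformer"
keys = list(state_dict.keys())

# Pipeline-level filtering, as in the diff line above.
transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}

# Prefix check inside load_lora_adapter(), as in the referenced peft.py line.
transformer_keys = [k for k in keys if k.startswith(prefix)]

print(sorted(transformer_state_dict) == sorted(transformer_keys))  # True
```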
@@ -341,7 +339,9 @@ def load_lora_into_text_encoder(
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
        low_cpu_mem_usage (`bool`, *optional*):
I saw that too, thx for fixing.
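For reference, the corrected docstring presumably keeps the parameter description under the parameter it documents, roughly like this (a sketch of the layout only; the signature is abbreviated and the exact wording is whatever ends up in the file):

```python
def load_lora_into_text_encoder(state_dict, network_alphas, text_encoder, adapter_name=None, low_cpu_mem_usage=False):
    """
    adapter_name (`str`, *optional*):
        Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
        `default_{i}` where i is the total number of adapters being loaded.
    low_cpu_mem_usage (`bool`, *optional*):
        Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
        weights.
    """
```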
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Thanks for addressing my comments, LGTM.
Ran the Flux integration tests and they pass. Failing tests are unrelated.
* add first draft.
* fix
* updates.
* updates.
* updates
* updates
* updates.
* fix-copies
* lora constants.
* add tests
* Apply suggestions from code review

  Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* docstrings.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
What does this PR do?
Similar to `load_attn_procs()`, we want a way to load LoRAs directly into models, as the LoRA loading logic is generic. This way, we can reduce the LoC and have better maintainability. I am not too fixated on the `load_lora_adapter()` name; could also do `load_adapter()`.

@DN6 as discussed via Slack, could you give this a check? Could also add a `save_lora_adapter()` method to complement it.
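For context, a hypothetical usage sketch of how a model-level `load_lora_adapter()` could sit next to the existing pipeline-level `load_lora_weights()` (the LoRA repo ID and the `adapter_name`/`prefix` arguments are assumptions here; the exact signature is whatever this PR settles on):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Existing pipeline-level entry point: routes LoRA weights to all targeted components.
pipe.load_lora_weights("some-user/some-flux-lora", adapter_name="style")

# Proposed model-level entry point: load a LoRA directly into a single model
# (here the transformer), without going through the pipeline loader.
pipe.transformer.load_lora_adapter(
    "some-user/some-flux-lora",
    adapter_name="style-direct",
    prefix="transformer",
)
```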