Commit 11c5f6b

Revert "[LoRA] introduce `LoraBaseMixin` to promote reusability." (#8773)

Reverts "[LoRA] introduce `LoraBaseMixin` to promote reusability. (#8670)"; this reverts commit a2071a1.

1 parent 2686552

File tree: 13 files changed, +1708 −2255 lines

docs/source/en/api/loaders/lora.md (+2 −17)

@@ -12,13 +12,10 @@ specific language governing permissions and limitations under the License.
 
 # LoRA
 
-LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the denoiser, text encoder or both. The denoiser usually corresponds to a UNet ([`UNet2DConditionModel`], for example) or a Transformer ([`SD3Transformer2DModel`], for example). There are several classes for loading LoRA weights:
+LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the UNet, text encoder or both. There are two classes for loading LoRA weights:
 
 - [`LoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
 - [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`LoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
-- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
-- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
-- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
 
 <Tip>
 
@@ -32,16 +29,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
 ## StableDiffusionXLLoraLoaderMixin
 
-[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
-
-## SD3LoraLoaderMixin
-
-[[autodoc]] loaders.lora.SD3LoraLoaderMixin
-
-## AmusedLoraLoaderMixin
-
-[[autodoc]] loaders.lora.AmusedLoraLoaderMixin
-
-## LoraBaseMixin
-
-[[autodoc]] loaders.lora_base.LoraBaseMixin
+[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
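The two remaining mixins are what give Stable Diffusion pipelines their LoRA entry points. A minimal sketch of that pipeline-level API (not part of this commit; the LoRA path below is a placeholder, and the SDXL base checkpoint is only an example):

```python
# Sketch of the LoRA management methods the loader mixins expose on a pipeline.
# Assumptions: an SDXL base checkpoint, a CUDA device, and a placeholder LoRA path.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# StableDiffusionXLLoraLoaderMixin supplies these methods on SDXL pipelines.
pipe.load_lora_weights("path/to/sdxl_lora")  # placeholder: a local directory or Hub repo with LoRA weights
pipe.fuse_lora()    # merge the LoRA deltas into the base weights for inference
image = pipe("a photo of a corgi wearing a party hat").images[0]
pipe.unfuse_lora()  # restore the original base weights
pipe.unload_lora_weights()  # remove the LoRA layers entirely
```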

examples/amused/train_amused.py (+5 −5)

@@ -41,7 +41,7 @@
 
 import diffusers.optimization
 from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
-from diffusers.loaders import AmusedLoraLoaderMixin
+from diffusers.loaders import LoraLoaderMixin
 from diffusers.utils import is_wandb_available
 
 
@@ -532,7 +532,7 @@ def save_model_hook(models, weights, output_dir):
             weights.pop()
 
         if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
-            AmusedLoraLoaderMixin.save_lora_weights(
+            LoraLoaderMixin.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_lora_layers_to_save,
@@ -566,11 +566,11 @@ def load_model_hook(models, input_dir):
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
         if transformer is not None or text_encoder_ is not None:
-            lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir)
-            AmusedLoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+            LoraLoaderMixin.load_lora_into_text_encoder(
                 lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
             )
-            AmusedLoraLoaderMixin.load_lora_into_transformer(
+            LoraLoaderMixin.load_lora_into_transformer(
                 lora_state_dict, network_alphas=network_alphas, transformer=transformer
            )
 
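Both hooks drive LoRA serialization entirely through `LoraLoaderMixin` class methods. A minimal sketch of that round trip, assuming the LoRA layer dicts and the `transformer`/`text_encoder` modules come from the caller's training setup (they are not defined here):

```python
# Sketch of the save/load round trip performed by the hooks in train_amused.py.
# The LoRA layer dicts and the models are assumed to be supplied by the caller.
from diffusers.loaders import LoraLoaderMixin


def save_lora(output_dir, transformer_lora_layers, text_encoder_lora_layers):
    # Serialize only the LoRA layers of both components into a single checkpoint.
    LoraLoaderMixin.save_lora_weights(
        output_dir,
        transformer_lora_layers=transformer_lora_layers,
        text_encoder_lora_layers=text_encoder_lora_layers,
    )


def load_lora(input_dir, transformer, text_encoder):
    # Read the combined state dict back and route each prefix to its component.
    lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
    LoraLoaderMixin.load_lora_into_text_encoder(
        lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder
    )
    LoraLoaderMixin.load_lora_into_transformer(
        lora_state_dict, network_alphas=network_alphas, transformer=transformer
    )
```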

src/diffusers/loaders/__init__.py (+2 −15)

@@ -55,18 +55,11 @@ def text_encoder_attn_modules(text_encoder):
 
 if is_torch_available():
     _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-    _import_structure["transformer_sd3"] = ["SD3TransformerLoadersMixin"]
-
     _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
         _import_structure["single_file"] = ["FromSingleFileMixin"]
-        _import_structure["lora"] = [
-            "AmusedLoraLoaderMixin",
-            "LoraLoaderMixin",
-            "SD3LoraLoaderMixin",
-            "StableDiffusionXLLoraLoaderMixin",
-        ]
+        _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]
         _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
         _import_structure["ip_adapter"] = ["IPAdapterMixin"]
 
@@ -76,18 +69,12 @@ def text_encoder_attn_modules(text_encoder):
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .single_file_model import FromOriginalModelMixin
-        from .transformer_sd3 import SD3TransformerLoadersMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
 
         if is_transformers_available():
             from .ip_adapter import IPAdapterMixin
-            from .lora import (
-                AmusedLoraLoaderMixin,
-                LoraLoaderMixin,
-                SD3LoraLoaderMixin,
-                StableDiffusionXLLoraLoaderMixin,
-            )
+            from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
             from .single_file import FromSingleFileMixin
             from .textual_inversion import TextualInversionLoaderMixin
 
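The `_import_structure` entries above are only resolved on first attribute access. A condensed sketch of how this `__init__.py` consumes them; the closing `_LazyModule` call follows the standard diffusers lazy-import pattern and is not shown in the hunks above, so treat it as an assumption:

```python
# Condensed sketch of the lazy-import wiring in src/diffusers/loaders/__init__.py.
# Only a subset of entries is shown; the _LazyModule tail is the usual diffusers
# pattern and is assumed here rather than taken from this commit's hunks.
import sys
from typing import TYPE_CHECKING

from diffusers.utils import (
    DIFFUSERS_SLOW_IMPORT,
    _LazyModule,
    is_torch_available,
    is_transformers_available,
)

_import_structure = {}

# Names are only registered when their backends are importable; nothing is imported yet.
if is_torch_available():
    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
    if is_transformers_available():
        _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager imports, used by type checkers or when slow import is explicitly requested.
    if is_torch_available():
        from .unet import UNet2DConditionLoadersMixin

        if is_transformers_available():
            from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
else:
    # Otherwise the module is replaced by a proxy that imports a submodule the
    # first time one of its registered names is accessed.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```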
