67 | 67 |      convert_state_dict_to_diffusers,
68 | 68 |      convert_state_dict_to_kohya,
69 | 69 |      convert_unet_state_dict_to_peft,
   | 70 | +    is_peft_version,
70 | 71 |      is_wandb_available,
71 | 72 |  )
72 | 73 |  from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
@@ -1183,26 +1184,33 @@ def main(args):
1183 | 1184 |  text_encoder_one.gradient_checkpointing_enable()
1184 | 1185 |  text_encoder_two.gradient_checkpointing_enable()
1185 | 1186 |
     | 1187 | + def get_lora_config(rank, use_dora, target_modules):
     | 1188 | +     base_config = {
     | 1189 | +         "r": rank,
     | 1190 | +         "lora_alpha": rank,
     | 1191 | +         "init_lora_weights": "gaussian",
     | 1192 | +         "target_modules": target_modules,
     | 1193 | +     }
     | 1194 | +     if use_dora:
     | 1195 | +         if is_peft_version("<", "0.9.0"):
     | 1196 | +             raise ValueError(
     | 1197 | +                 "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
     | 1198 | +             )
     | 1199 | +         else:
     | 1200 | +             base_config["use_dora"] = True
     | 1201 | +
     | 1202 | +     return LoraConfig(**base_config)
     | 1203 | +
1186 | 1204 |  # now we will add new LoRA weights to the attention layers
1187 |      | - unet_lora_config = LoraConfig(
1188 |      | -     r=args.rank,
1189 |      | -     use_dora=args.use_dora,
1190 |      | -     lora_alpha=args.rank,
1191 |      | -     init_lora_weights="gaussian",
1192 |      | -     target_modules=["to_k", "to_q", "to_v", "to_out.0"],
1193 |      | - )
     | 1205 | + unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
     | 1206 | + unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
1194 | 1207 |  unet.add_adapter(unet_lora_config)
1195 | 1208 |
1196 | 1209 |  # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
1197 | 1210 |  # So, instead, we monkey-patch the forward calls of its attention-blocks.
1198 | 1211 |  if args.train_text_encoder:
1199 |      | -     text_lora_config = LoraConfig(
1200 |      | -         r=args.rank,
1201 |      | -         use_dora=args.use_dora,
1202 |      | -         lora_alpha=args.rank,
1203 |      | -         init_lora_weights="gaussian",
1204 |      | -         target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
1205 |      | -     )
     | 1212 | +     text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     | 1213 | +     text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
1206 | 1214 |  text_encoder_one.add_adapter(text_lora_config)
1207 | 1215 |  text_encoder_two.add_adapter(text_lora_config)
1208 | 1216 |
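For context on the refactor above: the only behavioral addition is the guarded DoRA path, where `use_dora=True` is forwarded to `LoraConfig` only if the installed `peft` is at least 0.9.0, and older installs raise instead of silently dropping the flag. The snippet below is an illustrative standalone sketch, not part of the diff; the rank of 16 stands in for `args.rank`, and it assumes recent `diffusers` and `peft` installs.

```python
from diffusers.utils import is_peft_version
from peft import LoraConfig

rank = 16  # illustrative; stands in for args.rank
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]

# Same guard the new helper applies before enabling DoRA.
if is_peft_version("<", "0.9.0"):
    raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")

# Equivalent to get_lora_config(rank=rank, use_dora=True, target_modules=unet_target_modules).
unet_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    target_modules=unet_target_modules,
    use_dora=True,
)
print(unet_lora_config)
```

Passing the resulting config to `unet.add_adapter(unet_lora_config)` then injects the (DoRA-enabled) LoRA layers into the listed attention projections, exactly as the training script does above.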