import argparse
import copy
import gc
- import itertools
import logging
import math
import os
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
- from diffusers.models.attention_processor import (
-     AttnAddedKVProcessor,
-     AttnAddedKVProcessor2_0,
-     SlicedAttnAddedKVProcessor,
- )
- from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
- from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -864,79 +858,19 @@ def main(args):
            text_encoder.gradient_checkpointing_enable()

    # now we will add new LoRA weights to the attention layers
-     # It's important to realize here how many attention weights will be added and of which sizes
-     # The sizes of the attention layers consist only of two different variables:
-     # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
-     # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
-
-     # Let's first see how many attention processors we will have to set.
-     # For Stable Diffusion, it should be equal to:
-     # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
-     # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
-     # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
-     # => 32 layers
-
-     # Set correct lora layers
-     unet_lora_parameters = []
-     for attn_processor_name, attn_processor in unet.attn_processors.items():
-         # Parse the attention module.
-         attn_module = unet
-         for n in attn_processor_name.split(".")[:-1]:
-             attn_module = getattr(attn_module, n)
-
-         # Set the `lora_layer` attribute of the attention-related matrices.
-         attn_module.to_q.set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-             )
-         )
-         attn_module.to_k.set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-             )
-         )
-         attn_module.to_v.set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-             )
-         )
-         attn_module.to_out[0].set_lora_layer(
-             LoRALinearLayer(
-                 in_features=attn_module.to_out[0].in_features,
-                 out_features=attn_module.to_out[0].out_features,
-                 rank=args.rank,
-             )
-         )
-
-         # Accumulate the LoRA params to optimize.
-         unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-         unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-         unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-         unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
-
-         if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
-             attn_module.add_k_proj.set_lora_layer(
-                 LoRALinearLayer(
-                     in_features=attn_module.add_k_proj.in_features,
-                     out_features=attn_module.add_k_proj.out_features,
-                     rank=args.rank,
-                 )
-             )
-             attn_module.add_v_proj.set_lora_layer(
-                 LoRALinearLayer(
-                     in_features=attn_module.add_v_proj.in_features,
-                     out_features=attn_module.add_v_proj.out_features,
-                     rank=args.rank,
-                 )
-             )
-             unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
-             unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
+     unet_lora_config = LoraConfig(
+         r=args.rank,
+         init_lora_weights="gaussian",
+         target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+     )
+     unet.add_adapter(unet_lora_config)

-     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-     # So, instead, we monkey-patch the forward calls of its attention-blocks.
+     # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
    if args.train_text_encoder:
-         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-         text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32, rank=args.rank)
+         text_lora_config = LoraConfig(
+             r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
+         )
+         text_encoder.add_adapter(text_lora_config)

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
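
The hunk above replaces the manual LoRALinearLayer wiring with a declarative PEFT config. As a standalone illustration (not part of the diff), the sketch below attaches the same kind of adapter to a UNet and checks that only the injected LoRA matrices are trainable; the checkpoint id and rank value are assumptions for the example.

import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Assumed checkpoint id for illustration; any Stable Diffusion UNet behaves the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.requires_grad_(False)  # the training script freezes the base model before attaching the adapter

unet_lora_config = LoraConfig(
    r=4,  # assumed rank; the script takes this from args.rank
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
unet.add_adapter(unet_lora_config)

# After add_adapter, only the freshly injected LoRA matrices report requires_grad=True.
trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
total = sum(p.numel() for p in unet.parameters())
print(f"trainable params: {trainable:,} / {total:,}")
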
@@ -948,9 +882,9 @@ def save_model_hook(models, weights, output_dir):

            for model in models:
                if isinstance(model, type(accelerator.unwrap_model(unet))):
-                     unet_lora_layers_to_save = unet_lora_state_dict(model)
+                     unet_lora_layers_to_save = get_peft_model_state_dict(model)
                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-                     text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                     text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

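
In the save hook above, get_peft_model_state_dict extracts just the adapter tensors from a model that has had add_adapter called on it. A hedged sketch, continuing the standalone UNet example above; the key names shown are indicative, not copied from a run.

from peft.utils import get_peft_model_state_dict

lora_state_dict = get_peft_model_state_dict(unet)
# Expect only adapter weights, with keys such as
# "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight"
# and the matching "...lora_B.weight" entries; the frozen base weights are not included,
# so the serialized checkpoint stays small compared to the full UNet.
print(len(lora_state_dict), "adapter tensors")
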
@@ -1010,11 +944,10 @@ def load_model_hook(models, input_dir):
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
-     params_to_optimize = (
-         itertools.chain(unet_lora_parameters, text_lora_parameters)
-         if args.train_text_encoder
-         else unet_lora_parameters
-     )
+     params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+     if args.train_text_encoder:
+         params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
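
Because the adapter is the only trainable part of each model, the new parameter selection above is a plain requires_grad filter instead of hand-maintained lists of LoRA layers. A minimal sketch of the same idea, continuing the standalone UNet example; the learning rate is an assumed value.

import torch

# The filter picks out exactly the injected LoRA matrices, so the optimizer never
# sees the frozen base weights.
params_to_optimize = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)
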
@@ -1257,12 +1190,7 @@ def compute_text_embeddings(prompt):

                accelerator.backward(loss)
                if accelerator.sync_gradients:
-                     params_to_clip = (
-                         itertools.chain(unet_lora_parameters, text_lora_parameters)
-                         if args.train_text_encoder
-                         else unet_lora_parameters
-                     )
-                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                     accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
@@ -1385,19 +1313,19 @@ def compute_text_embeddings(prompt):
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
-         unet_lora_layers = unet_lora_state_dict(unet)

-         if text_encoder is not None and args.train_text_encoder:
+         unet_lora_state_dict = get_peft_model_state_dict(unet)
+
+         if args.train_text_encoder:
            text_encoder = accelerator.unwrap_model(text_encoder)
-             text_encoder = text_encoder.to(torch.float32)
-             text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+             text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
        else:
-             text_encoder_lora_layers = None
+             text_encoder_state_dict = None

        LoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
-             unet_lora_layers=unet_lora_layers,
-             text_encoder_lora_layers=text_encoder_lora_layers,
+             unet_lora_layers=unet_lora_state_dict,
+             text_encoder_lora_layers=text_encoder_state_dict,
        )

    # Final inference
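
Downstream, the checkpoint written by LoraLoaderMixin.save_lora_weights can be loaded back into a pipeline for inference. A hedged sketch of that round trip; the base model id, output path, and prompt are assumptions, not taken from the diff.

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Point this at the --output_dir used during training (hypothetical path).
pipeline.load_lora_weights("path/to/lora-output-dir")

image = pipeline("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-bucket.png")
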