
Commit 42bdd74

okotaku authored and patrickvonplaten committed
[Bugfix] fix error of peft lora when xformers enabled (huggingface#5697)
* bugfix peft lora

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent ca83ff1 commit 42bdd74
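
For context (not part of the commit page itself): the failure this fixes appears when a LoRA is loaded through the PEFT backend and xformers attention is then enabled, so the xformers and added-KV attention processors forward `scale=scale` to projection modules whose `forward()` no longer accepts that argument. A hedged reproduction sketch follows; the model ID and LoRA repo are placeholders chosen for illustration, not taken from the PR, and running it requires a GPU with xformers installed.

# Hedged reproduction sketch (placeholder model/LoRA identifiers).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some/lora-repo")            # placeholder LoRA repo id
pipe.enable_xformers_memory_efficient_attention()   # switches to XFormersAttnProcessor

# Before this commit, the call below could raise a TypeError inside the attention
# processor, because `scale` was forwarded as a keyword to PEFT-managed layers
# that do not take it. With the args-tuple pattern in this diff, the extra
# argument is simply dropped when USE_PEFT_BACKEND is true.
image = pipe("a prompt", num_inference_steps=20).images[0]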

File tree

1 file changed: +21 -19 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 21 additions & 19 deletions
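
Before reading the diff, here is a minimal sketch of the pattern the patch applies throughout attention_processor.py. `ToyLoRACompatibleLinear` is a toy stand-in for diffusers' `LoRACompatibleLinear`, and `USE_PEFT_BACKEND` is a local stand-in for the flag the real module imports from diffusers' utils; the point is only that splatting a conditionally empty tuple keeps the `scale` argument away from plain `nn.Linear` / PEFT-injected layers.

import torch
import torch.nn as nn

# Toy stand-in for LoRACompatibleLinear: its forward takes an optional `scale`
# that weights the LoRA delta on top of the base projection.
class ToyLoRACompatibleLinear(nn.Linear):
    def forward(self, hidden_states, scale: float = 1.0):
        base = super().forward(hidden_states)
        lora_delta = torch.zeros_like(base)  # placeholder for the LoRA branch
        return base + scale * lora_delta

USE_PEFT_BACKEND = True  # with the PEFT backend, projections are plain nn.Linear / PEFT layers

to_q = nn.Linear(8, 8) if USE_PEFT_BACKEND else ToyLoRACompatibleLinear(8, 8)
hidden_states, scale = torch.randn(2, 4, 8), 1.0

# The commit's pattern: build the extra-args tuple once, then splat it at every
# projection call. With the PEFT backend the tuple is empty, so nn.Linear only
# ever receives `hidden_states` and no unexpected `scale` argument is passed.
args = () if USE_PEFT_BACKEND else (scale,)
query = to_q(hidden_states, *args)
print(query.shape)  # torch.Size([2, 4, 8])

Passing `scale` positionally (rather than as `scale=scale`) is what makes the single call site work for both backends: `LoRACompatibleLinear` accepts it as its second positional parameter, while the empty tuple leaves PEFT-managed layers untouched.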
@@ -879,6 +879,9 @@ def __call__(
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape

@@ -891,17 +894,17 @@ def __call__(

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

@@ -946,6 +949,9 @@ def __call__(
         scale: float = 1.0,
     ) -> torch.Tensor:
         residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape

@@ -958,7 +964,7 @@ def __call__(

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         query = attn.head_to_batch_dim(query, out_dim=4)

         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ def __call__(
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ def __call__(
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

@@ -1177,6 +1183,8 @@ def __call__(
     ) -> torch.FloatTensor:
         residual = hidden_states

+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -1207,12 +1215,8 @@ def __call__(
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = (
-            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
-        )
-        value = (
-            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
-        )
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = (
-            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
-        )
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
