@@ -879,6 +879,9 @@ def __call__(
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

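For readers outside the PR context: under the PEFT backend the attention projections are plain `torch.nn.Linear` modules that accept only the input tensor, so no extra `scale` argument may be passed. A minimal sketch (throwaway names, not the diffusers API) of how the single `args` tuple keeps every call site valid on that path:

```python
import torch
import torch.nn as nn

# Assumption for this sketch: the PEFT backend swaps LoRA-aware layers for plain nn.Linear.
USE_PEFT_BACKEND = True

to_q = nn.Linear(8, 8)
hidden_states = torch.randn(2, 16, 8)
scale = 0.5

# Passing an extra argument to a plain nn.Linear fails, which is why the
# scale has to disappear entirely under the PEFT backend:
try:
    to_q(hidden_states, scale)
except TypeError as err:
    print("plain nn.Linear rejects the extra argument:", err)

# The tuple is built once per __call__ and unpacked at every projection:
args = () if USE_PEFT_BACKEND else (scale,)
query = to_q(hidden_states, *args)  # expands to to_q(hidden_states) here
print(query.shape)  # torch.Size([2, 16, 8])
```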
@@ -891,17 +894,17 @@ def __call__(

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
        query = attn.head_to_batch_dim(query)

-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ def __call__(
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

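On the non-PEFT path the same call sites hit a LoRA-aware linear layer whose `forward` takes an optional scale, so `*args` expands to one positional argument instead of none. The class below is a hypothetical stand-in written for this note, not the library's `LoRACompatibleLinear`:

```python
import torch
import torch.nn as nn

class ToyLoRALinear(nn.Linear):
    """Hypothetical LoRA-aware linear: base projection plus a scaled low-rank update."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__(in_features, out_features)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # `scale` weights the LoRA delta; scale=0.0 recovers the frozen base layer.
        base = super().forward(hidden_states)
        return base + scale * self.lora_up(self.lora_down(hidden_states))

USE_PEFT_BACKEND = False          # legacy LoRA path for this example
to_q = ToyLoRALinear(8, 8)
args = () if USE_PEFT_BACKEND else (0.5,)

# Same spelling as the diff: the scale rides along only when the layer expects it.
query = to_q(torch.randn(2, 16, 8), *args)
print(query.shape)  # torch.Size([2, 16, 8])
```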
@@ -946,6 +949,9 @@ def __call__(
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

@@ -958,7 +964,7 @@ def __call__(

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
        query = attn.head_to_batch_dim(query, out_dim=4)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ def __call__(
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, scale=scale)
-            value = attn.to_v(hidden_states, scale=scale)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ def __call__(
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

@@ -1177,6 +1183,8 @@ def __call__(
    ) -> torch.FloatTensor:
        residual = hidden_states

+        args = () if USE_PEFT_BACKEND else (scale,)
+
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -1207,12 +1215,8 @@ def __call__(
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = (
-            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
-        )
-        value = (
-            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
-        )
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ def __call__(
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
-        hidden_states = (
-            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
-        )
+        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

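The last two hunks drop the per-call-site ternaries in favor of the same `*args` spelling. A throwaway check (stand-in function, not the diffusers API) that the old and new spellings make exactly the same call under either backend setting:

```python
def to_out_0(hidden_states, scale=None):
    # Stand-in that just records the scale it received.
    return ("projected", scale)

hidden_states = "dummy tensor"
scale = 0.7

for USE_PEFT_BACKEND in (True, False):
    # Old spelling (removed by the diff): a ternary at every call site.
    old = (
        to_out_0(hidden_states, scale=scale) if not USE_PEFT_BACKEND else to_out_0(hidden_states)
    )
    # New spelling: compute `args` once, unpack it everywhere.
    args = () if USE_PEFT_BACKEND else (scale,)
    new = to_out_0(hidden_states, *args)
    assert old == new

print("old and new spellings are call-for-call equivalent")
```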