Commit 1616a63

[Core] add QKV fusion to AuraFlow and PixArt Sigma (#8952)
* add fusion support to pixart
* add to auraflow.
* add tests
* apply review feedback.
* add back args and kwargs
* style
1 parent 51f45da commit 1616a63
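
In practice, the new API added by this commit is called on a pipeline's `transformer` attribute (there is no pipeline-level helper yet, per the TODO in the new test). A minimal usage sketch; the checkpoint id and generation arguments below are illustrative assumptions, not part of the diff:

import torch
from diffusers import AuraFlowPipeline

# Illustrative checkpoint and settings; the commit itself only adds the
# fuse/unfuse methods and the fused attention processor.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# Fuse each attention layer's Q/K/V projections into a single matmul and swap in
# the fused processor added by this commit.
pipe.transformer.fuse_qkv_projections()
image = pipe("a photo of a cat", num_inference_steps=25).images[0]

# Restore the original per-projection layers and processors.
pipe.transformer.unfuse_qkv_projections()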

5 files changed: +344 −9 lines changed


src/diffusers/models/attention_processor.py

+106-5
@@ -227,6 +227,7 @@ def __init__(
             self.to_k = None
             self.to_v = None
 
+        self.added_proj_bias = added_proj_bias
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
@@ -698,12 +699,15 @@ def fuse_projections(self, fuse=True):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
-            self.to_added_qkv.weight.copy_(concatenated_weights)
-            concatenated_bias = torch.cat(
-                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+            self.to_added_qkv = nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
             )
-            self.to_added_qkv.bias.copy_(concatenated_bias)
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+                )
+                self.to_added_qkv.bias.copy_(concatenated_bias)
 
             self.fused_projections = fuse
 
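
The change in `fuse_projections` above makes the fused `to_added_qkv` honor `added_proj_bias` instead of hard-coding `bias=True`, and only copies a concatenated bias when one actually exists. The fusion itself rests on a simple identity: stacking the three projection weights row-wise yields a single linear layer whose output splits back into Q, K, and V. A self-contained sketch of that identity (not part of the diff; toy sizes):

import torch
import torch.nn as nn

dim = 8
q_proj, k_proj, v_proj = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Fused layer: weights concatenated along the output dimension, as in fuse_projections().
fused = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(2, 4, dim)
query, key, value = torch.split(fused(x), dim, dim=-1)
assert torch.allclose(query, q_proj(x), atol=1e-5)
assert torch.allclose(key, k_proj(x), atol=1e-5)
assert torch.allclose(value, v_proj(x), atol=1e-5)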
@@ -1274,6 +1278,103 @@ def __call__(
         return hidden_states
 
 
+class FusedAuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow with fused projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 # YiYi to-do: refactor rope related functions/classes
 def apply_rope(xq, xk, freqs_cis):
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
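
`FusedAuraFlowAttnProcessor2_0` mirrors `AuraFlowAttnProcessor2_0` but reads Q/K/V from the fused `to_qkv`/`to_added_qkv` layers. The main bookkeeping to note is the joint-attention layout: context tokens are prepended to the sample tokens before attention, and the output is split back on `encoder_hidden_states.shape[1]`. A shape-only sketch of that split (not from the diff):

import torch

batch, ctx_len, img_len, dim = 2, 5, 16, 8
encoder_hidden_states = torch.randn(batch, ctx_len, dim)   # text/context tokens
hidden_states = torch.randn(batch, img_len, dim)           # image/sample tokens

# Context tokens go first, exactly as in the processor above.
joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
attn_out = joint  # stand-in for the attention output, same sequence layout

sample_out = attn_out[:, encoder_hidden_states.shape[1]:]   # -> hidden_states
context_out = attn_out[:, :encoder_hidden_states.shape[1]]  # -> encoder_hidden_states
assert sample_out.shape == (batch, img_len, dim)
assert context_out.shape == (batch, ctx_len, dim)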

src/diffusers/models/transformers/auraflow_transformer_2d.py

+106-1
@@ -22,7 +22,12 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
+from ..attention_processor import (
+    Attention,
+    AttentionProcessor,
+    AuraFlowAttnProcessor2_0,
+    FusedAuraFlowAttnProcessor2_0,
+)
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -320,6 +325,106 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
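
The `attn_processors` property and `set_attn_processor` copied in above follow the usual diffusers pattern: walk `named_children()` recursively, collect anything exposing `get_processor()` under a dotted, weight-name-style key, and set processors the same way. A toy, self-contained sketch of that recursion (`ToyAttention` and `ToyBlock` are made up for illustration):

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self._processor = "processor-object"

    def get_processor(self):
        return self._processor

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

model = nn.ModuleDict({"blocks": nn.ModuleList([ToyBlock(), ToyBlock()])})

def collect(name, module, processors):
    # Record modules that expose get_processor(), then recurse into children.
    if hasattr(module, "get_processor"):
        processors[f"{name}.processor"] = module.get_processor()
    for sub_name, child in module.named_children():
        collect(f"{name}.{sub_name}", child, processors)
    return processors

processors = {}
for name, module in model.named_children():
    collect(name, module, processors)

print(sorted(processors))  # ['blocks.0.attn.processor', 'blocks.1.attn.processor']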

src/diffusers/models/transformers/pixart_transformer_2d.py

+41-1
@@ -19,7 +19,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import AttentionProcessor
+from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,46 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
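
Unlike AuraFlow, PixArt reuses the generic `FusedAttnProcessor2_0` rather than a model-specific processor, but the calling pattern is the same. A minimal sketch on the PixArt Sigma side (the checkpoint id and call arguments are illustrative assumptions, not part of the diff):

import torch
from diffusers import PixArtSigmaPipeline

# Illustrative checkpoint; only fuse_qkv_projections()/unfuse_qkv_projections()
# on pipe.transformer come from this commit.
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

pipe.transformer.fuse_qkv_projections()    # swaps in FusedAttnProcessor2_0
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
pipe.transformer.unfuse_qkv_projections()  # restores the saved original processors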

tests/pipelines/aura_flow/test_pipeline_aura_flow.py

+45-1
@@ -9,7 +9,11 @@
     torch_device,
 )
 
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)
 
 
 class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -119,3 +123,43 @@ def test_attention_slicing_forward_pass(self):
         # Attention slicing needs to implemented differently for this because how single DiT and MMDiT
         # blocks interfere with each other.
         return
+
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."

0 commit comments
