Commit 2d9ccf3

Authored Aug 23, 2024

[Core] fuse_qkv_projection() to Flux (#9185)

* start fusing flux.
* test
* finish fusion
* fix-copies

1 parent 960c149 · commit 2d9ccf3
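The new API is exposed on the Flux transformer rather than on the pipeline. A minimal usage sketch (not part of this diff; the checkpoint id, dtype, and prompt are illustrative):

```python
import torch
from diffusers import FluxPipeline

# Illustrative checkpoint; any Flux checkpoint with a `transformer` component works the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Fuse the separate Q/K/V projections into a single `to_qkv` linear in every attention
# module and switch the attention processors to `FusedFluxAttnProcessor2_0`.
pipe.transformer.fuse_qkv_projections()

image = pipe("a photo of a cat").images[0]

# Restore the original attention processors.
pipe.transformer.unfuse_qkv_projections()
```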

File tree

3 files changed: +245 −3 lines changed

- src/diffusers/models/attention_processor.py
- src/diffusers/models/transformers/transformer_flux.py
- tests/pipelines/flux/test_pipeline_flux.py

src/diffusers/models/attention_processor.py (+94)
@@ -1783,6 +1783,100 @@ def __call__(
         return hidden_states


+class FusedFluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
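The fused processor assumes `attn.to_qkv` (and `attn.to_added_qkv`) already hold the concatenated query/key/value weights produced by `Attention.fuse_projections`, so a single matmul followed by `torch.split` reproduces the three separate projections. A small self-contained sketch of that equivalence with plain `torch.nn.Linear` layers (illustrative, not the diffusers `Attention` module):

```python
import torch
import torch.nn as nn

dim = 64
x = torch.randn(2, 16, dim)

to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# "Fusing" concatenates the three weight matrices (and biases) into one projection.
to_qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

# One matmul plus a split matches the three separate projections.
q, k, v = torch.split(to_qkv(x), dim, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(k, to_k(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)
```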

src/diffusers/models/transformers/transformer_flux.py (+106 −1)
@@ -23,7 +23,12 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, FluxAttnProcessor2_0
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -276,6 +281,106 @@ def __init__(

         self.gradient_checkpointing = False

+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
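A rough sketch of how the new model-level methods interact with the `attn_processors` map (not part of the diff; the checkpoint id is illustrative and loading the full transformer needs substantial memory):

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

print({type(p).__name__ for p in transformer.attn_processors.values()})
# expected: {'FluxAttnProcessor2_0'}

transformer.fuse_qkv_projections()
print({type(p).__name__ for p in transformer.attn_processors.values()})
# expected: {'FusedFluxAttnProcessor2_0'}

# unfuse_qkv_projections() restores the processors saved in `original_attn_processors`.
transformer.unfuse_qkv_projections()
```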

tests/pipelines/flux/test_pipeline_flux.py (+45 −2)
@@ -13,10 +13,13 @@
     torch_device,
 )

-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)


-@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
 class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
     pipeline_class = FluxPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -143,6 +146,46 @@ def test_flux_prompt_embeds(self):
         max_diff = np.abs(output_with_prompt - output_with_embeds).max()
         assert max_diff < 1e-4

+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(
+            pipe.transformer
+        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+

 @slow
 @require_torch_gpu
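The `check_qkv_fusion_*` helpers are imported from `tests/pipelines/test_pipelines_common.py`, which is not part of this diff. A hypothetical sketch of the kind of invariants they verify (the actual implementations may differ):

```python
def check_qkv_fusion_processors_exist(model):
    # After fusion, every attention processor should be a "Fused..." variant.
    return all(type(proc).__name__.startswith("Fused") for proc in model.attn_processors.values())


def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors):
    # Fusion should not add or drop attention layers.
    return len(model.attn_processors) == len(original_attn_processors)
```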
