 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, FluxAttnProcessor2_0
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -276,6 +281,106 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
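For reference, a minimal usage sketch of the API this diff adds. It assumes the surrounding class is diffusers' `FluxTransformer2DModel` (this file is the Flux transformer); the checkpoint id is illustrative, and `fuse_qkv_projections` is experimental per its own docstring.

```python
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0

# Illustrative checkpoint; any Flux transformer weights should work here.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# `attn_processors` returns a dict keyed by module path, e.g.
# "transformer_blocks.0.attn.processor" -> processor instance.
print(len(transformer.attn_processors))

# `set_attn_processor` takes a single processor applied to all layers,
# or a dict with one entry per attention layer.
transformer.set_attn_processor(FluxAttnProcessor2_0())

# Fuse the Q/K/V projections into one matmul per self-attention layer;
# the pre-fusion processors are stashed internally so this is reversible.
transformer.fuse_qkv_projections()
# ... run inference with fused projections ...
transformer.unfuse_qkv_projections()
```

The fused path trades three smaller projection matmuls for one larger one per attention layer, which tends to be friendlier to GPU kernel launch overhead; whether it helps in practice depends on hardware and shapes.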