
Commit aa71132

elismasilva authored and sayakpaul committed

Feature IP Adapter Xformers Attention Processor (#9881)

* Feature IP Adapter Xformers Attention Processor: this fixes the wrong attention processor being loaded when xFormers attention is enabled after loading an IP adapter and setting the IP adapter scale. Issues: #8863, #8872

1 parent 291db3e, commit aa71132
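The affected workflow, as a minimal sketch (the checkpoint names and scale value below are illustrative, not taken from the commit): load an IP adapter, then enable xFormers memory-efficient attention, then adjust the adapter scale. Before this change, enabling xFormers swapped the IP-Adapter attention processors for plain XFormersAttnProcessor, so the adapter weights and scale were lost.

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoints; any pipeline with an IP-Adapter behaves the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# With this commit, enabling xformers converts the IP-Adapter layers to
# IPAdapterXFormersAttnProcessor (keeping the to_k_ip/to_v_ip weights) instead of
# replacing them with a plain XFormersAttnProcessor.
pipe.enable_xformers_memory_efficient_attention()
pipe.set_ip_adapter_scale(0.6)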

File tree

3 files changed: +278 -11 lines changed

src/diffusers/loaders/ip_adapter.py (+8 -6)

@@ -33,16 +33,14 @@


 if is_transformers_available():
-    from transformers import (
-        CLIPImageProcessor,
-        CLIPVisionModelWithProjection,
-    )
+    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

     from ..models.attention_processor import (
         AttnProcessor,
         AttnProcessor2_0,
         IPAdapterAttnProcessor,
         IPAdapterAttnProcessor2_0,
+        IPAdapterXFormersAttnProcessor,
     )

 logger = logging.get_logger(__name__)

@@ -284,7 +282,9 @@ def set_ip_adapter_scale(self, scale):
         scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

         for attn_name, attn_processor in unet.attn_processors.items():
-            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+            if isinstance(
+                attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+            ):
                 if len(scale_configs) != len(attn_processor.scale):
                     raise ValueError(
                         f"Cannot assign {len(scale_configs)} scale_configs to "

@@ -342,7 +342,9 @@ def unload_ip_adapter(self):
             )
             attn_procs[name] = (
                 attn_processor_class
-                if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+                if isinstance(
+                    value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+                )
                 else value.__class__()
             )
         self.unet.set_attn_processor(attn_procs)
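Continuing the pipeline sketch above (illustrative, not part of the diff): with the widened isinstance checks, set_ip_adapter_scale() and unload_ip_adapter() now also recognize layers carrying the xFormers IP-Adapter processor.

from diffusers.models.attention_processor import IPAdapterXFormersAttnProcessor

# After enable_xformers_memory_efficient_attention(), the cross-attention layers
# hosting IP-Adapter weights should use the xformers variant.
n_ip_layers = sum(
    isinstance(proc, IPAdapterXFormersAttnProcessor) for proc in pipe.unet.attn_processors.values()
)
print(f"{n_ip_layers} attention layers use IPAdapterXFormersAttnProcessor")

pipe.set_ip_adapter_scale(0.5)  # now matches the xformers processors too
pipe.unload_ip_adapter()        # and restores plain attention processors on unload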

src/diffusers/loaders/unet.py (+9 -4)

@@ -765,6 +765,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
         from ..models.attention_processor import (
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
+            IPAdapterXFormersAttnProcessor,
         )

         if low_cpu_mem_usage:

@@ -804,11 +805,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
             if cross_attention_dim is None or "motion_modules" in name:
                 attn_processor_class = self.attn_processors[name].__class__
                 attn_procs[name] = attn_processor_class()
-
             else:
-                attn_processor_class = (
-                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
-                )
+                if "XFormers" in str(self.attn_processors[name].__class__):
+                    attn_processor_class = IPAdapterXFormersAttnProcessor
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
                 num_image_text_embeds = []
                 for state_dict in state_dicts:
                     if "proj.weight" in state_dict["image_proj"]:

src/diffusers/models/attention_processor.py (+261 -1)

@@ -318,7 +318,10 @@ def set_use_memory_efficient_attention_xformers(
                 XFormersAttnAddedKVProcessor,
             ),
         )
-
+        is_ip_adapter = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
+        )
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(

@@ -368,6 +371,19 @@
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_ip_adapter:
+                processor = IPAdapterXFormersAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:

@@ -386,6 +402,18 @@
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_ip_adapter:
+                processor = IPAdapterAttnProcessor2_0(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 # set attention processor
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
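A round-trip sketch of the conversion above (illustrative, not part of the diff; layer sizes are made up and it requires a CUDA device with xformers installed): toggling memory-efficient attention on an Attention module carrying an IP-Adapter processor now converts between the SDPA and xFormers IP-Adapter processors while copying the projection weights via load_state_dict().

import torch
from diffusers.models.attention_processor import (
    Attention,
    IPAdapterAttnProcessor2_0,
    IPAdapterXFormersAttnProcessor,
)

# Hypothetical layer sizes for an SD 1.5-style cross-attention block.
attn = Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40).to("cuda", torch.float16)
attn.set_processor(
    IPAdapterAttnProcessor2_0(hidden_size=320, cross_attention_dim=768, num_tokens=[4]).to("cuda", torch.float16)
)

attn.set_use_memory_efficient_attention_xformers(True)
assert isinstance(attn.processor, IPAdapterXFormersAttnProcessor)  # weights and scale carried over

attn.set_use_memory_efficient_attention_xformers(False)
assert isinstance(attn.processor, IPAdapterAttnProcessor2_0)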
@@ -4542,6 +4570,238 @@ def __call__(
         return hidden_states


+class IPAdapterXFormersAttnProcessor(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter using xFormers.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+            The context length of the image features.
+        scale (`float` or `List[float]`, defaults to 1.0):
+            the weight scale of image prompt.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        cross_attention_dim=None,
+        num_tokens=(4,),
+        scale=1.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+        self.num_tokens = num_tokens
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+    ):
+        residual = hidden_states
+
+        # separate ip_hidden_states from encoder_hidden_states
+        if encoder_hidden_states is not None:
+            if isinstance(encoder_hidden_states, tuple):
+                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+            else:
+                deprecation_message = (
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+                )
+                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+                encoder_hidden_states, ip_hidden_states = (
+                    encoder_hidden_states[:, :end_pos, :],
+                    [encoder_hidden_states[:, end_pos:, :]],
+                )
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # expand our mask's singleton query_tokens dimension:
+            #   [batch*heads,            1, key_tokens] ->
+            #   [batch*heads, query_tokens, key_tokens]
+            # so that it can be added as a bias onto the attention scores that xformers computes:
+            #   [batch*heads, query_tokens, key_tokens]
+            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if ip_hidden_states:
+            if ip_adapter_masks is not None:
+                if not isinstance(ip_adapter_masks, List):
+                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                    raise ValueError(
+                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                        f"({len(ip_hidden_states)})"
+                    )
+                else:
+                    for index, (mask, scale, ip_state) in enumerate(
+                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
+                    ):
+                        if mask is None:
+                            continue
+                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                            raise ValueError(
+                                "Each element of the ip_adapter_masks array should be a tensor with shape "
+                                "[1, num_images_for_ip_adapter, height, width]."
+                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                            )
+                        if mask.shape[1] != ip_state.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                            )
+                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of scales ({len(scale)}) at index {index}"
+                            )
+            else:
+                ip_adapter_masks = [None] * len(self.scale)
+
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        mask = mask.to(torch.float16)
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                            ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+                            _current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+                                query, ip_key, ip_value, op=self.attention_op
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+                            _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)
+
+                        ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                        ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+                        current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+                            query, ip_key, ip_value, op=self.attention_op
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+                        current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class PAGIdentitySelfAttnProcessor2_0:
     r"""
     Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
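For completeness, a direct-construction sketch of the new processor (illustrative, not part of the diff; the sizes are hypothetical and the weights here are randomly initialized, whereas in practice they come from the IP-Adapter state dict via load_state_dict): it can also be built by hand and attached to a layer with set_processor, mirroring what set_use_memory_efficient_attention_xformers now does automatically.

import torch
from diffusers.models.attention_processor import IPAdapterXFormersAttnProcessor

proc = IPAdapterXFormersAttnProcessor(
    hidden_size=320,          # inner dim of the attention layer
    cross_attention_dim=768,  # text encoder hidden size for an SD 1.5-style UNet
    num_tokens=[4],           # four image tokens per IP-Adapter
    scale=[1.0],              # one scale per adapter
    attention_op=None,        # let xformers choose the kernel
).to("cuda", torch.float16)   # match the layer's device/dtype before assigning
attn.set_processor(proc)      # `attn` as in the round-trip sketch above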
