Add warning if Attention Processor is changed after loading IP Adapter #8872
We could add an xformers variant of the IP-Adapter attention processor, similarly to how we handle custom_diffusion.
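For context, the custom_diffusion handling referred to here lives in `Attention.set_use_memory_efficient_attention_xformers`. The sketch below paraphrases it from memory (simplified, not a verbatim copy of the diffusers source), and the closing comment names the missing IP-Adapter analogue:

```python
# Simplified paraphrase of how diffusers already handles custom_diffusion when
# xformers is toggled on: swap the processor for its xformers counterpart and
# carry the trained weights over instead of dropping them.
if is_custom_diffusion:
    processor = CustomDiffusionXFormersAttnProcessor(
        train_kv=self.processor.train_kv,
        train_q_out=self.processor.train_q_out,
        hidden_size=self.processor.hidden_size,
        cross_attention_dim=self.processor.cross_attention_dim,
        attention_op=attention_op,
    )
    processor.load_state_dict(self.processor.state_dict())
    if hasattr(self.processor, "to_k_custom_diffusion"):
        processor.to(self.processor.to_k_custom_diffusion.weight.device)
else:
    processor = XFormersAttnProcessor(attention_op=attention_op)
self.set_processor(processor)

# An IP-Adapter equivalent would need an xformers-capable IP-Adapter processor,
# which is exactly what the rest of this thread builds.
```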
Hi @yiyixuxu, I would appreciate it if we had an IPAdapterXformerAttnProcessor. I tested with SDP and it is working fine for style transfer, but with xformers it doesn't work. Thinking about the pipeline, though: when loading the IP Adapter, does it need to set this new attention processor? We need to check whether xformers was enabled on the pipeline in order to use the correct processor. The diffusers/src tree has many references to IPAdapterAttnProcessor2_0; I am working in a local copy and will need to check whether xformers was applied. How can I do this? In my tests I saw the xformers processor applied to all processors if I enable xformers on the pipeline before load_ip_adapter, but afterwards it is replaced by SDP.
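As a side note on the "how can I check whether xformers was applied?" question, one illustrative (unofficial) check is to inspect the processor classes registered on the UNet:

```python
# Illustrative helper, not an official diffusers API: report whether xformers
# attention processors are currently registered on a UNet.
def unet_uses_xformers(unet) -> bool:
    return any(
        "XFormers" in processor.__class__.__name__
        for processor in unet.attn_processors.values()
    )

# After `pipe.enable_xformers_memory_efficient_attention()` this returns True;
# after a subsequent `pipe.load_ip_adapter(...)` it returns False again, because
# the IP-Adapter loader installs SDP-based processors.
```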
I did this and it is working for me; you guys can reuse it for production.

```python
# Imports needed if this class lives outside diffusers' attention_processor.py
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import xformers.ops

from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.models.attention_processor import Attention
from diffusers.utils import deprecate


class IPAdapterXformerAttnProcessor(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter using xFormers.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or `List[float]`, defaults to 1.0):
            The weight scale of the image prompt.
    """

    def __init__(
        self,
        hidden_size,
        cross_attention_dim=None,
        num_tokens=(4,),
        scale=1.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        # TODO attention_mask
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        return hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_adapter_masks is not None:
            if not isinstance(ip_adapter_masks, List):
                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                raise ValueError(
                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                    f"({len(ip_hidden_states)})"
                )
            else:
                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                        raise ValueError(
                            "Each element of the ip_adapter_masks array should be a tensor with shape "
                            "[1, num_images_for_ip_adapter, height, width]."
                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                        )
                    if mask.shape[1] != ip_state.shape[1]:
                        raise ValueError(
                            f"Number of masks ({mask.shape[1]}) does not match "
                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
                        )
                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
                        raise ValueError(
                            f"Number of masks ({mask.shape[1]}) does not match "
                            f"number of scales ({len(scale)}) at index {index}"
                        )
        else:
            ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter
        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            skip = False
            if isinstance(scale, list):
                if all(s == 0 for s in scale):
                    skip = True
            elif scale == 0:
                skip = True
            if not skip:
                if mask is not None:
                    if not isinstance(scale, list):
                        scale = [scale] * mask.shape[1]

                    current_num_images = mask.shape[1]
                    for i in range(current_num_images):
                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                        ip_key = attn.head_to_batch_dim(ip_key)
                        ip_value = attn.head_to_batch_dim(ip_value)

                        _current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
                        _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

                        mask_downsample = IPAdapterMaskProcessor.downsample(
                            mask[:, i, :, :],
                            batch_size,
                            _current_ip_hidden_states.shape[1],
                            _current_ip_hidden_states.shape[2],
                        )
                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
                else:
                    ip_key = to_k_ip(current_ip_hidden_states)
                    ip_value = to_v_ip(current_ip_hidden_states)

                    ip_key = attn.head_to_batch_dim(ip_key)
                    ip_value = attn.head_to_batch_dim(ip_value)

                    current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
                    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

                    hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
```
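For illustration (my sketch, not part of the comment above): assuming the class above is in scope, one way to apply it to a UNet that already went through `load_ip_adapter` is to swap the SDP IP-Adapter processors for the xformers variant and copy their weights:

```python
from diffusers.models.attention_processor import (
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)


def swap_ip_adapter_processors_to_xformers(unet):
    """Replace SDP IP-Adapter processors with IPAdapterXformerAttnProcessor, keeping weights."""
    new_procs = {}
    for name, proc in unet.attn_processors.items():
        if isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
            xformers_proc = IPAdapterXformerAttnProcessor(
                hidden_size=proc.hidden_size,
                cross_attention_dim=proc.cross_attention_dim,
                num_tokens=proc.num_tokens,
                scale=proc.scale,
            )
            # reuse the to_k_ip / to_v_ip weights already loaded by `load_ip_adapter`
            xformers_proc.load_state_dict(proc.state_dict())
            xformers_proc.to(device=unet.device, dtype=unet.dtype)
            new_procs[name] = xformers_proc
        else:
            new_procs[name] = proc
    unet.set_attn_processor(new_procs)
```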
I did a workaround in this file, because SDP is prioritized even when xformers is enabled: diffusers/src/diffusers/loaders/unet.py, line 972 at 3b37fef

```python
if cross_attention_dim is None or "motion_modules" in name:
    if "XFormers" not in str(self.attn_processors[name].__class__):
        attn_processor_class = (
            AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
        )
        attn_procs[name] = attn_processor_class()
    else:
        attn_procs[name] = self.attn_processors[name]
else:
    if "XFormers" in str(self.attn_processors[name].__class__):
        attn_processor_class = IPAdapterXformerAttnProcessor
    else:
        attn_processor_class = (
            IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
        )
```

So I also included the new class in these parts of the code:

- diffusers/src/diffusers/loaders/ip_adapter.py, line 284 at 3b37fef
- diffusers/src/diffusers/loaders/ip_adapter.py, line 336 at 3b37fef

For this to work, xformers needs to be enabled on the pipeline before loading the adapters.
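In other words, on an already-built pipeline (`pipe` is assumed here; the adapter repo and weight names are just the usual example values):

```python
# Enable xformers first, then load the IP Adapter, so the patched loader can
# install the xformers-aware IP-Adapter processors instead of falling back to SDP.
pipe.enable_xformers_memory_efficient_attention()
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```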
@elismasilva excellent work
@elismasilva
I've never done a PR on diffusers, so I need to understand what's involved. I didn't make these changes against the current version, so I'll need to test them on the latest release, but right now I'm finishing other work. If it's not urgent, I can try to create a PR as soon as I'm free.
@elismasilva
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi folks, I believe we have a solution here by @elismasilva. This would fix both this issue and #8863. Would you like to take it up @elismasilva? We'd be happy to help in the PR, but if not, I'll be happy to take over from your solution and add you as an author on the commit.
Hi @a-r-r-o-w, I didn't have time to make the PR, but recently I needed to adjust this code in my local solution due to the latest repository updates. I'll update my local code and test this again; if it's working, I'll submit the PR.
…ng incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: huggingface#8863 huggingface#8872
Fixed in #9881.
Using a pipeline method that changes the attention processors after loading the IP Adapter can lead to weird errors when running the pipeline, e.g. #8863.
For example, enabling xformers after `load_ip_adapter` replaces the IP-Adapter attention processors, and the next pipeline call fails (see the sketch below).
Perhaps we could add a warning message if a pipeline has already loaded the IP-Adapter attention processors and then attempts to change them?
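A minimal ordering that illustrates the problem (the checkpoint and adapter names are placeholders; the exact traceback depends on the pipeline):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# This replaces every attention processor, including the IP-Adapter ones that
# were just loaded, so the next pipeline call with `ip_adapter_image` fails.
pipe.enable_xformers_memory_efficient_attention()
```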
cc: @yiyixuxu