Add warning if Attention Processor is changed after loading IP Adapter #8872

Closed
DN6 opened this issue Jul 16, 2024 · 12 comments

Comments

@DN6
Collaborator

DN6 commented Jul 16, 2024

Using a pipeline method that changes the attention processors after loading the IP Adapter can lead to confusing errors when running the pipeline (e.g. #8863).

For example:

pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter-plus_sd15.bin"
)
pipe.set_ip_adapter_scale(0.7)
pipe.enable_xformers_memory_efficient_attention()

This will lead to the following error:

AttributeError: 'tuple' object has no attribute 'shape'

Perhaps we could add a warning message if a pipeline has already loaded the IP Adapter attention processors and then attempts to change them?
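
As a rough illustration only, a minimal sketch of the kind of check that could emit such a warning; the helper name and where it would be called from are assumptions, not the actual diffusers implementation:

from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0
from diffusers.utils import logging

logger = logging.get_logger(__name__)


def warn_if_ip_adapter_processors_loaded(unet):
    # Hypothetical helper (name and call site are assumptions): warn when the UNet
    # already carries IP-Adapter attention processors and something is about to
    # replace them, e.g. enable_xformers_memory_efficient_attention().
    if any(
        isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
        for proc in unet.attn_processors.values()
    ):
        logger.warning(
            "This pipeline has IP-Adapter attention processors loaded. Changing the "
            "attention implementation now will replace them and can break inference."
        )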

cc: @yiyixuxu

@yiyixuxu
Collaborator

yiyixuxu commented Jul 16, 2024

We could add an IPAdapterXformerAttnProcessor!
We just have to swap out scaled_dot_product_attention for xformers.ops.memory_efficient_attention(...)

hidden_states = xformers.ops.memory_efficient_attention(

similarly to how we handle custom_diffusion

def set_use_memory_efficient_attention_xformers(
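
Roughly, a minimal sketch of the swap (assuming query/key/value are already reshaped the way the existing IP-Adapter processors reshape them; not the final implementation):

import torch
import xformers.ops


def xformers_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask=None):
    # Drop-in replacement for F.scaled_dot_product_attention in an IP-Adapter
    # processor: xFormers wants contiguous tensors and takes the mask as `attn_bias`.
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    return xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)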

@elismasilva
Contributor

elismasilva commented Jul 17, 2024

IPAdapterXformerAttnProcessor

Hi @yiyixuxu, I would appreciate having an IPAdapterXformerAttnProcessor. I tested with SDP and it works fine for style transfer, but with xFormers it doesn't work. Thinking about the pipeline, though: does the place where the IP Adapter is loaded need to set this new attention processor? We need to check whether xFormers was enabled on the pipeline in order to use the correct attention processor.

This file in diffusers/src has many references to IPAttnProcessor2. I am working on a local copy; how can I check whether xFormers was applied?

In my tests I saw the xFormers attention processor applied to all processors when I enabled xFormers on the pipeline before load_ip_adapter, but afterwards it is replaced by SDP.

@elismasilva
Contributor

I did this and it is working for me; you can reuse it for production.

# Imports needed when using this outside of diffusers' attention_processor.py:
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import xformers.ops

from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.models.attention_processor import Attention
from diffusers.utils import deprecate


class IPAdapterXformerAttnProcessor(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter using xFormers.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or `List[float]`, defaults to 1.0):
            the weight scale of image prompt.
    """
    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        # xFormers expects contiguous tensors; the attention mask is passed as `attn_bias`.
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        return hidden_states
      
    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, attention_op: Optional[Callable] = None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op
        
        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)          

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
      
        hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)

        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_adapter_masks is not None:
            if not isinstance(ip_adapter_masks, List):
                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                raise ValueError(
                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                    f"({len(ip_hidden_states)})"
                )
            else:
                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                        raise ValueError(
                            "Each element of the ip_adapter_masks array should be a tensor with shape "
                            "[1, num_images_for_ip_adapter, height, width]."
                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                        )
                    if mask.shape[1] != ip_state.shape[1]:
                        raise ValueError(
                            f"Number of masks ({mask.shape[1]}) does not match "
                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
                        )
                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
                        raise ValueError(
                            f"Number of masks ({mask.shape[1]}) does not match "
                            f"number of scales ({len(scale)}) at index {index}"
                        )
        else:
            ip_adapter_masks = [None] * len(self.scale)

        # for ip-adapter             
        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
        ):
            skip = False
            if isinstance(scale, list):
                if all(s == 0 for s in scale):
                    skip = True
            elif scale == 0:
                skip = True
            if not skip:
                if mask is not None:
                    if not isinstance(scale, list):
                        scale = [scale] * mask.shape[1]

                    current_num_images = mask.shape[1]
                    for i in range(current_num_images):
                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                        ip_key = attn.head_to_batch_dim(ip_key)
                        ip_value = attn.head_to_batch_dim(ip_value)
                        
                        _current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
                       
                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
                        _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

                        mask_downsample = IPAdapterMaskProcessor.downsample(
                            mask[:, i, :, :],
                            batch_size,
                            _current_ip_hidden_states.shape[1],
                            _current_ip_hidden_states.shape[2],
                        )

                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
                else:
                    ip_key = to_k_ip(current_ip_hidden_states)
                    ip_value = to_v_ip(current_ip_hidden_states)

                    ip_key = attn.head_to_batch_dim(ip_key)
                    ip_value = attn.head_to_batch_dim(ip_value)
                    
                    current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)

                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
                    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

                    hidden_states = hidden_states + scale * current_ip_hidden_states
            
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
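
For quick testing, this is roughly how the processor can be registered by hand on a UNet. This is a sketch only: the hidden-size lookup follows the usual SD UNet layout, and copying the trained to_k_ip/to_v_ip weights into the processors is omitted.

from diffusers.models.attention_processor import XFormersAttnProcessor

unet = pipe.unet  # assumes a loaded pipeline

attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross_attention_dim), attn2 is cross-attention
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        # self-attention layers keep the plain xFormers processor
        attn_procs[name] = XFormersAttnProcessor()
    else:
        # cross-attention layers get the IP-Adapter-aware xFormers processor defined above
        attn_procs[name] = IPAdapterXformerAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )
unet.set_attn_processor(attn_procs)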

@elismasilva
Contributor

elismasilva commented Jul 17, 2024

I did a workaround in this file, because SDP is prioritized even when xFormers is enabled:

if cross_attention_dim is None or "motion_modules" in name:

I changed it to the following, but I think you have a better way to implement this.

        if cross_attention_dim is None or "motion_modules" in name:
            if "XFormers" not in str(self.attn_processors[name].__class__):
                attn_processor_class = (
                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
                )
                attn_procs[name] = attn_processor_class()
            else:
                attn_procs[name] = self.attn_processors[name]
        else:
            if "XFormers" in str(self.attn_processors[name].__class__):
                attn_processor_class = IPAdapterXformerAttnProcessor
            else:
                attn_processor_class = (
                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
                )

So I included the new class in these parts of the code too:

if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):

if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))

For this to work, xFormers needs to be enabled on the pipeline before loading the adapters.
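
A quick usage sketch with the snippet from the issue description (the base checkpoint here is just an example):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Enable xFormers first, then load the IP-Adapter, so the xFormers-aware
# processors are selected instead of being overwritten by the SDP ones.
pipe.enable_xformers_memory_efficient_attention()
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter-plus_sd15.bin",
)
pipe.set_ip_adapter_scale(0.7)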

@FurkanGozukara

@elismasilva excellent work

@yiyixuxu
Collaborator

@elismasilva
thanks for your work! would you be willing to open a PR?

@elismasilva
Contributor

@elismasilva thanks for your work! would you be willing to open a PR?

I've never done a PR on diffusers, so I need to understand what I need to do. I didn't make these changes in the current version, so I'll need to test them on the latest version. Right now I'm finishing other work; if it's not urgent, I can try to create a PR as soon as I'm free.

@yiyixuxu
Collaborator

@elismasilva
it is not urgent! we have a guide on how to open a PR here https://huggingface.co/docs/diffusers/en/conceptual/contribution#how-to-open-a-pr

Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 14, 2024
@a-r-r-o-w
Member

Hi folks, I believe we have a solution here by @elismasilva that would fix both this issue and #8863. Would you like to take it up, @elismasilva? We'd be happy to help with the PR, but if not, I'll be happy to take over from your solution and add you as an author on the commit.

@a-r-r-o-w a-r-r-o-w removed the stale Issues that haven't received updates label Oct 31, 2024
@elismasilva
Contributor

Hi @a-r-r-o-w, I didn't have time to make the PR, but I recently needed to adjust this code in my local solution due to the latest repository updates. I'll update my local code and test it again; if it's working, I'll submit the PR.

elismasilva added a commit to DEVAIEXP/diffusers that referenced this issue Nov 6, 2024
…ng incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: huggingface#8863 huggingface#8872
yiyixuxu pushed a commit that referenced this issue Nov 9, 2024
* Feature IP Adapter Xformers Attention Processor: this fix error loading incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: #8863 #8872
@yiyixuxu
Collaborator

yiyixuxu commented Nov 9, 2024

fixed in #9881
thanks a lot @elismasilva!!!

@yiyixuxu yiyixuxu closed this as completed Nov 9, 2024
sayakpaul pushed a commit that referenced this issue Dec 23, 2024
* Feature IP Adapter Xformers Attention Processor: this fix error loading incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: #8863 #8872