@@ -318,7 +318,10 @@ def set_use_memory_efficient_attention_xformers(
                 XFormersAttnAddedKVProcessor,
             ),
         )
-
+        is_ip_adapter = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
+        )
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
@@ -368,6 +371,19 @@ def set_use_memory_efficient_attention_xformers(
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_ip_adapter:
+                processor = IPAdapterXFormersAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -386,6 +402,18 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_ip_adapter:
+                processor = IPAdapterAttnProcessor2_0(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 # set attention processor
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
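Taken together, the three hunks above make `set_use_memory_efficient_attention_xformers` IP-Adapter aware: enabling xformers swaps `IPAdapterAttnProcessor`/`IPAdapterAttnProcessor2_0` for `IPAdapterXFormersAttnProcessor` (copying the `to_k_ip`/`to_v_ip` weights and matching device/dtype), and disabling it swaps back to `IPAdapterAttnProcessor2_0`. A minimal sketch of the code path that exercises this dispatch, assuming xformers is installed; the checkpoint and adapter IDs are illustrative only:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative IDs: any checkpoint with a matching IP-Adapter works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# Routes through the dispatch above:
# IPAdapterAttnProcessor2_0 -> IPAdapterXFormersAttnProcessor (state dict is carried over).
pipe.enable_xformers_memory_efficient_attention()

# Takes the new `elif is_ip_adapter` branch in the non-xformers path and swaps back.
pipe.disable_xformers_memory_efficient_attention()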
@@ -4542,6 +4570,238 @@ def __call__(
         return hidden_states
 
 
+class IPAdapterXFormersAttnProcessor(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter using xFormers.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+            The context length of the image features.
+        scale (`float` or `List[float]`, defaults to 1.0):
+            the weight scale of image prompt.
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        cross_attention_dim=None,
+        num_tokens=(4,),
+        scale=1.0,
+        attention_op: Optional[Callable] = None,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.attention_op = attention_op
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+        self.num_tokens = num_tokens
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+    ):
+        residual = hidden_states
+
+        # separate ip_hidden_states from encoder_hidden_states
+        if encoder_hidden_states is not None:
+            if isinstance(encoder_hidden_states, tuple):
+                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+            else:
+                deprecation_message = (
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+                )
+                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+                encoder_hidden_states, ip_hidden_states = (
+                    encoder_hidden_states[:, :end_pos, :],
+                    [encoder_hidden_states[:, end_pos:, :]],
+                )
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # expand our mask's singleton query_tokens dimension:
+            #   [batch*heads, 1, key_tokens] ->
+            #   [batch*heads, query_tokens, key_tokens]
+            # so that it can be added as a bias onto the attention scores that xformers computes:
+            #   [batch*heads, query_tokens, key_tokens]
+            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+            _, query_tokens, _ = hidden_states.shape
+            attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if ip_hidden_states:
+            if ip_adapter_masks is not None:
+                if not isinstance(ip_adapter_masks, List):
+                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                    raise ValueError(
+                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                        f"({len(ip_hidden_states)})"
+                    )
+                else:
+                    for index, (mask, scale, ip_state) in enumerate(
+                        zip(ip_adapter_masks, self.scale, ip_hidden_states)
+                    ):
+                        if mask is None:
+                            continue
+                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                            raise ValueError(
+                                "Each element of the ip_adapter_masks array should be a tensor with shape "
+                                "[1, num_images_for_ip_adapter, height, width]."
+                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                            )
+                        if mask.shape[1] != ip_state.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                            )
+                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                            raise ValueError(
+                                f"Number of masks ({mask.shape[1]}) does not match "
+                                f"number of scales ({len(scale)}) at index {index}"
+                            )
+            else:
+                ip_adapter_masks = [None] * len(self.scale)
+
+            # for ip-adapter
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                skip = False
+                if isinstance(scale, list):
+                    if all(s == 0 for s in scale):
+                        skip = True
+                elif scale == 0:
+                    skip = True
+                if not skip:
+                    if mask is not None:
+                        mask = mask.to(torch.float16)
+                        if not isinstance(scale, list):
+                            scale = [scale] * mask.shape[1]
+
+                        current_num_images = mask.shape[1]
+                        for i in range(current_num_images):
+                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                            ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                            ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+                            _current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+                                query, ip_key, ip_value, op=self.attention_op
+                            )
+                            _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+                            _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                            mask_downsample = IPAdapterMaskProcessor.downsample(
+                                mask[:, i, :, :],
+                                batch_size,
+                                _current_ip_hidden_states.shape[1],
+                                _current_ip_hidden_states.shape[2],
+                            )
+
+                            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                    else:
+                        ip_key = to_k_ip(current_ip_hidden_states)
+                        ip_value = to_v_ip(current_ip_hidden_states)
+
+                        ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                        ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+                        current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+                            query, ip_key, ip_value, op=self.attention_op
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+                        current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class PAGIdentitySelfAttnProcessor2_0:
     r"""
     Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
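For reference, the processor added above can also be constructed directly with the arguments documented in its docstring. This is only a sketch under the assumption that the class is exported from `diffusers.models.attention_processor` once this change lands; the sizes are arbitrary, and in normal use the processor is created by `load_ip_adapter` plus the xformers dispatch rather than by hand:

from diffusers.models.attention_processor import IPAdapterXFormersAttnProcessor

# One IP-Adapter with 4 image tokens; `scale` weights the image-prompt branch.
proc = IPAdapterXFormersAttnProcessor(
    hidden_size=1280,          # width of the attention layer (arbitrary example value)
    cross_attention_dim=2048,  # width of encoder_hidden_states (arbitrary example value)
    num_tokens=(4,),
    scale=0.6,
    attention_op=None,         # let xFormers choose the best operator
)

# __init__ creates one to_k_ip/to_v_ip projection per adapter.
print(len(proc.to_k_ip), len(proc.to_v_ip))  # 1 1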