From 0af910be2f4ee308b2c2f8cadafd94e95d1b434e Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Thu, 21 Nov 2024 18:18:46 +0000 Subject: [PATCH 01/44] Initial pipeline for SD3.5-Large-IP-Adapter --- .../pipeline_stable_diffusion_3_ipa.py | 1423 +++++++++++++++++ src/diffusers/models/attention.py | 8 +- .../models/transformers/transformer_sd3.py | 15 +- 3 files changed, 1434 insertions(+), 12 deletions(-) create mode 100644 examples/community/pipeline_stable_diffusion_3_ipa.py diff --git a/examples/community/pipeline_stable_diffusion_3_ipa.py b/examples/community/pipeline_stable_diffusion_3_ipa.py new file mode 100644 index 000000000000..490f65b7a719 --- /dev/null +++ b/examples/community/pipeline_stable_diffusion_3_ipa.py @@ -0,0 +1,1423 @@ +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput + +from diffusers.models.transformers import SD3Transformer2DModel +from diffusers.models.normalization import RMSNorm +from einops import rearrange +import math + +from diffusers.models.embeddings import Timesteps, TimestepEmbedding + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3Pipeline + + >>> pipe = StableDiffusion3Pipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt).images[0] + >>> image.save("sd3.png") + ``` +""" + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents, shift=None, scale=None): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + if shift is not None and scale is not None: + latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +class TimeResampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + timestep_in_dim=320, + timestep_flip_sin_to_cos=True, + timestep_freq_shift=0, + ): + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + # msa + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + # ff + FeedForward(dim=dim, mult=ff_mult), + # adaLN + nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)) + ] + ) + ) + + # time + self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) + self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu") + + # adaLN + # self.adaLN_modulation = nn.Sequential( + # nn.SiLU(), + # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True) + # ) + + + def forward(self, x, timestep, need_temb=False): + timestep_emb = self.embedding_time(x, timestep) # bs, dim + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + x = x + timestep_emb[:, None] + + for attn, ff, adaLN_modulation in self.layers: + shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1) + latents = attn(x, latents, shift_msa, scale_msa) + latents + + res = latents + for idx_ff in range(len(ff)): + layer_ff = ff[idx_ff] + latents = layer_ff(latents) + if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN + latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + latents = latents + res + + # latents = ff(latents) + latents + + latents = self.proj_out(latents) + latents = self.norm_out(latents) + + if need_temb: + return latents, timestep_emb + else: + return latents + + + + def embedding_time(self, sample, timestep): + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, None) + return emb + + +class AdaLayerNorm(nn.Module): + """ + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'): + super().__init__() + + self.silu = nn.SiLU() + num_params_dict = dict( + zero=6, + normal=2, + ) + num_params = num_params_dict[mode] + self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + self.mode = mode + + def forward( + self, + x, + hidden_dtype = None, + emb = None, + ): + emb = self.linear(self.silu(emb)) + if self.mode == 'normal': + shift_msa, scale_msa = emb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x + + elif self.mode == 'zero': + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class JointIPAttnProcessor(torch.nn.Module): + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ip_hidden_states_dim=None, + ip_encoder_hidden_states_dim=None, + head_dim=None, + timesteps_emb_dim=1280, + ): + super().__init__() + + self.norm_ip = AdaLayerNorm(ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim) + self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.norm_q = RMSNorm(head_dim, 1e-6) + self.norm_k = RMSNorm(head_dim, 1e-6) + self.norm_ip_k = RMSNorm(head_dim, 1e-6) + + + def __call__( + self, + attn, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + emb_dict=None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + img_query = query + img_key = key + img_value = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + + # IPadapter + ip_hidden_states = emb_dict.get('ip_hidden_states', None) + ip_hidden_states = self.get_ip_hidden_states( + attn, + img_query, + ip_hidden_states, + img_key, + img_value, + None, + None, + emb_dict['temb'], + ) + if ip_hidden_states is not None: + hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0) + + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + + def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None): + if ip_hidden_states is None: + return None + + if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'): + return None + + # norm ip input + norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=temb) + + # to k and v + ip_key = self.to_k_ip(norm_ip_hidden_states) + ip_value = self.to_v_ip(norm_ip_hidden_states) + + # reshape + query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads) + img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads) + img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads) + ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads) + ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads) + + # norm + query = self.norm_q(query) + img_key = self.norm_k(img_key) + ip_key = self.norm_ip_k(ip_key) + + # cat img + key = torch.cat([img_key, ip_key], dim=2) + value = torch.cat([img_value, ip_value], dim=2) + + # + ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)') + ip_hidden_states = ip_hidden_states.to(query.dtype) + return ip_hidden_states + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + self.tokenizer_max_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + + @torch.inference_mode() + def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432): + from transformers import SiglipVisionModel, SiglipImageProcessor + state_dict = torch.load(ip_adapter_path, map_location="cpu") + + device, dtype = self.transformer.device, self.transformer.dtype + image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path) + image_processor = SiglipImageProcessor.from_pretrained(image_encoder_path) + image_encoder.eval() + image_encoder.to(device, dtype=dtype) + self.image_encoder = image_encoder + self.clip_image_processor = image_processor + + sample_class = TimeResampler + image_proj_model = sample_class( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=nb_token, + embedding_dim=1152, + output_dim=output_dim, + ff_mult=4, + timestep_in_dim=320, + timestep_flip_sin_to_cos=True, + timestep_freq_shift=0, + ) + image_proj_model.eval() + image_proj_model.to(device, dtype=dtype) + key_name = image_proj_model.load_state_dict(state_dict["image_proj"], strict=False) + print(f"=> loading image_proj_model: {key_name}") + + self.image_proj_model = image_proj_model + + + attn_procs = {} + transformer = self.transformer + for idx_name, name in enumerate(transformer.attn_processors.keys()): + hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads + ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads + ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim + + attn_procs[name] = JointIPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=transformer.config.caption_projection_dim, + ip_hidden_states_dim=ip_hidden_states_dim, + ip_encoder_hidden_states_dim=ip_encoder_hidden_states_dim, + head_dim=transformer.config.attention_head_dim, + timesteps_emb_dim=1280, + ).to(device, dtype=dtype) + + self.transformer.set_attn_processor(attn_procs) + tmp_ip_layers = torch.nn.ModuleList(self.transformer.attn_processors.values()) + + key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) + print(f"=> loading ip_adapter: {key_name}") + + + @torch.inference_mode() + def encode_clip_image_emb(self, clip_image, device, dtype): + + # clip + clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values + clip_image_tensor = clip_image_tensor.to(device, dtype=dtype) + clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2] + clip_image_embeds = torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0) + + return clip_image_embeds + + + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + + # ipa + clip_image=None, + ipadapter_scale=1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. prepare clip emb + clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size))) + clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + image_prompt_embeds, timestep_emb = self.image_proj_model( + clip_image_embeds, + timestep.to(dtype=latents.dtype), + need_temb=True + ) + + joint_attention_kwargs = dict( + emb_dict=dict( + ip_hidden_states=image_prompt_embeds, + temb=timestep_emb, + scale=ipadapter_scale, + ) + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) \ No newline at end of file diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02ed1f965abf..37af68a9b7f1 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -188,7 +188,8 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): self._chunk_dim = dim def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, + joint_attention_kwargs=None, ): if self.use_dual_attention: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( @@ -206,7 +207,8 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, + **({} if joint_attention_kwargs is None else joint_attention_kwargs), ) # Process attention outputs for the `hidden_states`. @@ -214,7 +216,7 @@ def forward( hidden_states = hidden_states + attn_output if self.use_dual_attention: - attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),) attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 hidden_states = hidden_states + attn_output2 diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index a89a5e26ee97..f916c08fe557 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -268,7 +268,6 @@ def forward( block_controlnet_hidden_states: List = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - skip_layers: Optional[List[int]] = None, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`SD3Transformer2DModel`] forward method. @@ -291,8 +290,6 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. - skip_layers (`list` of `int`, *optional*): - A list of layer indices to skip during the forward pass. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a @@ -320,10 +317,7 @@ def forward( encoder_hidden_states = self.context_embedder(encoder_hidden_states) for index_block, block in enumerate(self.transformer_blocks): - # Skip specified layers - is_skip = True if skip_layers is not None and index_block in skip_layers else False - - if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip: + if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -340,11 +334,14 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states, temb, + joint_attention_kwargs, **ckpt_kwargs, ) - elif not is_skip: + + else: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, + joint_attention_kwargs=joint_attention_kwargs, ) # controlnet residual From 55674387f29f8aa370d3e5f51ec9f91565d21540 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Fri, 6 Dec 2024 15:09:06 +0000 Subject: [PATCH 02/44] Added support for single IPAdapter on SD3.5 pipeline --- src/diffusers/models/attention.py | 6 +- src/diffusers/models/attention_processor.py | 140 ++++++++ src/diffusers/models/embeddings.py | 148 +++++++++ .../pipeline_stable_diffusion_3.py | 302 +++++++++++++++++- 4 files changed, 583 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 37af68a9b7f1..bb761d1a99c2 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -189,7 +189,7 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, - joint_attention_kwargs=None, + joint_attention_kwargs: Dict[str, Any] = {} ): if self.use_dual_attention: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( @@ -208,7 +208,7 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, - **({} if joint_attention_kwargs is None else joint_attention_kwargs), + **joint_attention_kwargs ) # Process attention outputs for the `hidden_states`. @@ -216,7 +216,7 @@ def forward( hidden_states = hidden_states + attn_output if self.use_dual_attention: - attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),) + attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs) attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 hidden_states = hidden_states + attn_output2 diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ffbf4a0056c6..5bd84452397b 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,6 +18,7 @@ import torch import torch.nn.functional as F from torch import nn +from einops import rearrange from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging @@ -4800,6 +4801,144 @@ def __call__( hidden_states = hidden_states / attn.rescale_output_factor return hidden_states + + +class IPAdapterJointAttnProcessor2_0(torch.nn.Module): + """Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections.""" + + def __init__( + self, + hidden_size: int, + ip_hidden_states_dim: int, + head_dim: int, + timesteps_emb_dim: int = 1280, + scale: float = 0.5 + ): + super().__init__() + + # To prevent circular import + from .normalization import RMSNorm, AdaLayerNorm + + self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, + norm_eps=1e-6, chunk_dim=1) + self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.norm_q = RMSNorm(head_dim, 1e-6) + self.norm_k = RMSNorm(head_dim, 1e-6) + self.norm_ip_k = RMSNorm(head_dim, 1e-6) + self.scale = scale + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + ip_hidden_states: torch.FloatTensor = None, + temb: torch.FloatTensor = None + ) -> torch.FloatTensor: + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + img_query = query + img_key = key + img_value = value + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP Adapter + if self.scale != 0 and ip_hidden_states is not None: + # Norm image features + norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb) + + # To k and v + ip_key = self.to_k_ip(norm_ip_hidden_states) + ip_value = self.to_v_ip(norm_ip_hidden_states) + + # Reshape + img_query = rearrange(img_query, 'b l (h d) -> b h l d', h=attn.heads) + img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads) + img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads) + ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads) + ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads) + + # Norm + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + ip_key = self.norm_ip_k(ip_key) + + # cat img + img_key = torch.cat([img_key, ip_key], dim=2) + img_value = torch.cat([img_value, ip_value], dim=2) + + ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False) + ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)') + ip_hidden_states = ip_hidden_states.to(img_query.dtype) + + hidden_states = hidden_states + ip_hidden_states * self.scale + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states class PAGIdentitySelfAttnProcessor2_0: @@ -5089,6 +5228,7 @@ def __init__(self): IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, + IPAdapterJointAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, LoRAAttnProcessor, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 80775d477c0d..10e8e49c7321 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1999,6 +1999,154 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: return out +# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +class TimePerceiverAttention(nn.Module): + def __init__( + self, + *, + dim: int, + dim_head: int = 64, + heads: int = 8, + ) -> None: + super().__init__() + + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents, shift=None, scale=None): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + def reshape_tensor(x, heads): + bs, length, _ = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + return x.reshape(bs, heads, length, -1) + + x = self.norm1(x) + latents = self.norm2(latents) + + if shift is not None and scale is not None: + latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +class TimePerceiverResampler(nn.Module): + def __init__( + self, + embed_dim: int = 1152, + output_dim: int = 2432, + hidden_dim: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 20, + num_queries: int = 64, + ffn_ratio: int = 4, + timestep_in_dim: int = 320, + timestep_flip_sin_to_cos: bool = True, + timestep_freq_shift: int = 0, + ) -> None: + super().__init__() + + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim ** 0.5) + self.proj_in = nn.Linear(embed_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + ff_inner_dim = int(hidden_dim * ffn_ratio) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + # msa + TimePerceiverAttention(dim=hidden_dim, dim_head=dim_head, heads=heads), + # ff + nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, ff_inner_dim, bias=False), + nn.GELU(), + nn.Linear(ff_inner_dim, hidden_dim, bias=False), + ), + # adaLN + nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_dim, ff_inner_dim, bias=True) + ) + ] + ) + ) + + # Time + self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) + self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu") + + def forward(self, x, timestep, need_temb=False): + timestep_emb = self.time_proj(timestep).to(dtype=x.dtype) + timestep_emb = self.time_embedding(timestep_emb, None) + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + x = x + timestep_emb[:, None] + + for attn, ff, adaLN_modulation in self.layers: + shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1) + latents = attn(x, latents, shift_msa, scale_msa) + latents + + res = latents + for idx_ff in range(len(ff)): + layer_ff = ff[idx_ff] + latents = layer_ff(latents) + if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN + latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + latents = latents + res + + latents = self.proj_out(latents) + latents = self.norm_out(latents) + + if need_temb: + return latents, timestep_emb + else: + return latents + + class MultiIPAdapterImageProjection(nn.Module): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index a77231cdc02d..65fdbbe2a0d7 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,18 +16,29 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch +from safetensors import safe_open from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, + PreTrainedModel, + BaseImageProcessor, ) -from ...image_processor import VaeImageProcessor +from ...image_processor import VaeImageProcessor, PipelineImageInput from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel +from ...models.embeddings import TimePerceiverResampler from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...models.attention_processor import IPAdapterJointAttnProcessor2_0 +from ...models.modeling_utils import ( + load_model_dict_into_meta, + _LOW_CPU_MEM_USAGE_DEFAULT, + load_state_dict, +) +from huggingface_hub.utils import validate_hf_hub_args from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, @@ -35,6 +46,9 @@ replace_example_docstring, scale_lora_layers, unscale_lora_layers, + is_torch_version, + is_accelerate_available, + _get_model_file, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -160,10 +174,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`PreTrainedModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`BaseImageProcessor`, *optional*): + Image processor for IP Adapter. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" - _optional_components = [] + _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( @@ -177,6 +195,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, + image_encoder: PreTrainedModel = None, + feature_extractor: BaseImageProcessor = None ): super().__init__() @@ -190,6 +210,8 @@ def __init__( tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -668,7 +690,228 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str, + subfolder: Optional[str] = None, + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `dict`): + Can be either: + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + weight_name (`str`): + The name of the weight file to load. + subfolder (`str, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + + # Load the main state dict first + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage: + if is_accelerate_available(): + from accelerate import init_empty_weights + + else: + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + if list(state_dict.keys()) != ["image_proj", "ip_adapter"]: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + # Load ip_adapter + hidden_size = self.transformer.config.attention_head_dim * self.transformer.config.num_attention_heads + ip_hidden_states_dim = self.transformer.config.attention_head_dim * self.transformer.config.num_attention_heads + timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] + + # State dict by layer + layer_state_dict = {idx: {} for idx in range(len(self.transformer.attn_processors))} + for key, weights in state_dict["ip_adapter"].items(): + idx, name = key.split(".", 1) + layer_state_dict[int(idx)][name] = weights + + attn_procs = {} + for idx, name in enumerate(self.transformer.attn_processors.keys()): + attn_procs[name] = IPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, + ip_hidden_states_dim=ip_hidden_states_dim, + head_dim=self.transformer.config.attention_head_dim, + timesteps_emb_dim=timesteps_emb_dim, + ).to(self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(layer_state_dict[idx]) + else: + load_model_dict_into_meta(attn_procs[name], layer_state_dict, device=self.device, dtype=self.dtype) + + self.transformer.set_attn_processor(attn_procs) + + # Load image_proj + embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dim = state_dict["image_proj"]["latents"].shape[2] + heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 + num_queries = state_dict["image_proj"]["latents"].shape[1] + timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] + + self.image_proj = TimePerceiverResampler( + embed_dim=embed_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + heads=heads, + num_queries=num_queries, + timestep_in_dim=timestep_in_dim + ).to(device=self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) + else: + load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) + + def set_ip_adapter_scale(self, scale): + """ + Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image prompt, and 0.0 + only conditioned by the text prompt. Lowering this value encourages the model to produce more diverse images, but they + may not be as aligned with the image prompt. + """ + for attn_processor in self.transformes.attn_processors.values(): + if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0): + attn_processor.scale = scale + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image + def encode_image(self, image): + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=self.device, dtype=self.dtype) + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + uncond_image_enc_hidden_states = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2] + + return image_enc_hidden_states, uncond_image_enc_hidden_states + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + # image_embeds = [] + + # if do_classifier_free_guidance: + # negative_image_embeds = [] + + # if ip_adapter_image_embeds is None: + # single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image) + # image_embeds.append(single_image_embeds[None, :]) + + # if do_classifier_free_guidance: + # negative_image_embeds.append(single_negative_image_embeds[None, :]) + # else: + # for single_image_embeds in ip_adapter_image_embeds: + # if do_classifier_free_guidance: + # single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + # negative_image_embeds.append(single_negative_image_embeds) + # image_embeds.append(single_image_embeds) + + # ip_adapter_image_embeds = [] + # for i, single_image_embeds in enumerate(image_embeds): + # single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + # if do_classifier_free_guidance: + # single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + # single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + # single_image_embeds = single_image_embeds.to(device=device) + # ip_adapter_image_embeds.append(single_image_embeds) + + + # Single image only :/ + clip_image_tensor = self.feature_extractor(images=ip_adapter_image, return_tensors="pt").pixel_values + clip_image_tensor = clip_image_tensor.to(device, dtype=self.dtype) + clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2] + + return torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -691,17 +934,19 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + joint_attention_kwargs: Dict[str, Any] = {}, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, skip_guidance_layers: List[int] = None, - skip_layer_guidance_scale: int = 2.8, - skip_layer_guidance_stop: int = 0.2, - skip_layer_guidance_start: int = 0.01, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_start: float = 0.01, ): r""" Function invoked when calling the pipeline for generation. @@ -766,6 +1011,12 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -900,7 +1151,17 @@ def __call__( latents, ) - # 6. Denoising loop + # 6. Prepare image embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -912,16 +1173,34 @@ def __call__( if self.do_classifier_free_guidance and skip_guidance_layers is None else latents ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_hidden_states, temb = self.image_proj( + image_embeds, + timestep.to(dtype=latents.dtype), + need_temb=True, + ) + + image_prompt_embeds = dict( + ip_hidden_states = ip_hidden_states, + temb = temb + ) + else: + image_prompt_embeds = {} + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, + joint_attention_kwargs={ + **image_prompt_embeds, + **self.joint_attention_kwargs, + } )[0] # perform guidance @@ -940,7 +1219,10 @@ def __call__( timestep=timestep, encoder_hidden_states=original_prompt_embeds, pooled_projections=original_pooled_prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, + joint_attention_kwargs={ + **image_prompt_embeds, + **self.joint_attention_kwargs, + }, return_dict=False, skip_layers=skip_guidance_layers, )[0] From 0ef36dd247eac3f5a4ae959d99b3ff234cba72e3 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Sat, 7 Dec 2024 14:31:09 +0000 Subject: [PATCH 03/44] Fixed typo and reverted removal of skip_layers in SD3Transformer2DModel --- src/diffusers/models/transformers/transformer_sd3.py | 11 ++++++++--- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 37d86220011e..4865a52bb5fa 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -341,6 +341,7 @@ def forward( block_controlnet_hidden_states: List = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, + skip_layers: Optional[List[int]] = None, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`SD3Transformer2DModel`] forward method. @@ -363,6 +364,8 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. + skip_layers (`list` of `int`, *optional*): + A list of layer indices to skip during the forward pass. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a @@ -390,7 +393,10 @@ def forward( encoder_hidden_states = self.context_embedder(encoder_hidden_states) for index_block, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: + # Skip specified layers + is_skip = True if skip_layers is not None and index_block in skip_layers else False + + if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): @@ -410,8 +416,7 @@ def custom_forward(*inputs): joint_attention_kwargs, **ckpt_kwargs, ) - - else: + elif not is_skip: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, joint_attention_kwargs=joint_attention_kwargs, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 5f643b7fddb7..443ce3787182 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -855,7 +855,7 @@ def set_ip_adapter_scale(self, scale): only conditioned by the text prompt. Lowering this value encourages the model to produce more diverse images, but they may not be as aligned with the image prompt. """ - for attn_processor in self.transformes.attn_processors.values(): + for attn_processor in self.transformer.attn_processors.values(): if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0): attn_processor.scale = scale From de8909acfdec063a19e1f92facbc5b6f883f34b2 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 00:03:03 +0000 Subject: [PATCH 04/44] Added new SD3IPAdapterMixin loader --- src/diffusers/loaders/__init__.py | 10 +- src/diffusers/loaders/ip_adapter.py | 232 +++++++++++++++++- .../models/transformers/transformer_sd3.py | 58 ++++- .../pipeline_stable_diffusion_3.py | 186 +------------- 4 files changed, 291 insertions(+), 195 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 007d3c95597a..e4ffdd2324af 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -71,7 +71,10 @@ def text_encoder_attn_modules(text_encoder): "Mochi1LoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] - _import_structure["ip_adapter"] = ["IPAdapterMixin"] + _import_structure["ip_adapter"] = [ + "IPAdapterMixin", + "SD3IPAdapterMixin", + ] _import_structure["peft"] = ["PeftAdapterMixin"] @@ -83,7 +86,10 @@ def text_encoder_attn_modules(text_encoder): from .utils import AttnProcsLayers if is_transformers_available(): - from .ip_adapter import IPAdapterMixin + from .ip_adapter import ( + IPAdapterMixin, + SD3IPAdapterMixin, + ) from .lora_pipeline import ( AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index ca460f948e6f..0960d90f018f 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -33,16 +33,23 @@ if is_transformers_available(): - from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection - - from ..models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - IPAdapterAttnProcessor, - IPAdapterAttnProcessor2_0, - IPAdapterXFormersAttnProcessor, + from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + SiglipImageProcessor, + SiglipVisionModel ) +from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + JointAttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, + IPAdapterJointAttnProcessor2_0, +) + logger = logging.get_logger(__name__) @@ -348,3 +355,212 @@ def unload_ip_adapter(self): else value.__class__() ) self.unet.set_attn_processor(attn_procs) + + +class SD3IPAdapterMixin: + """Mixin for handling StableDiffusion 3 IP Adapters.""" + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + subfolder: str, + weight_name: str, + image_encoder_folder: Optional[str] = "image_encoder", + **kwargs, + ): + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + subfolder (`str`): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + weight_name (`str`): + The name of the weight file to load. If a list is passed, it should have the same length as + `subfolder`. + image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): + The subfolder location of the image encoder within a larger model repository on the Hub or locally. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # Load the main state dict first + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_folder is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + if image_encoder_folder.count("/") == 0: + image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() + else: + image_encoder_subfolder = Path(image_encoder_folder).as_posix() + + # Commons args for loading image encoder and image processor + args = dict( + pretrained_model_name_or_path_or_dict, + subfolder=image_encoder_subfolder, + low_cpu_mem_usage=low_cpu_mem_usage, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + + self.register_modules( + feature_extractor = SiglipImageProcessor.from_pretrained(**args).to(self.device, dtype=self.dtype), + image_encoder = SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype), + ) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + ) + + # Load IP-Adapter into transformer + self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage) + + def set_ip_adapter_scale(self, scale: float): + """ + Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image prompt, and 0.0 + only conditioned by the text prompt. Lowering this value encourages the model to produce more diverse images, but they + may not be as aligned with the image prompt. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.set_ip_adapter_scale(0.6) + >>> ... + ``` + """ + for attn_processor in self.transformer.attn_processors.values(): + if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0): + attn_processor.scale = scale + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # Remove image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=None) + + # Remove feature extractor + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=None) + + # Remove image projection + self.transformer.image_proj = None + + # Restore original attention processors layers + attn_procs = { + name: ( + JointAttnProcessor2_0() + if isinstance(value, IPAdapterJointAttnProcessor2_0) + else value.__class__() + ) + for name, value in self.transformer.attn_processors.items() + } + self.transformer.set_attn_processor(attn_procs) \ No newline at end of file diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 4865a52bb5fa..1b8fac537496 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -27,12 +27,13 @@ AttentionProcessor, FusedJointAttnProcessor2_0, JointAttnProcessor2_0, + IPAdapterJointAttnProcessor2_0, ) -from ...models.modeling_utils import ModelMixin +from ...models.modeling_utils import ModelMixin, load_model_dict_into_meta from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed, TimePerceiverResampler from ..modeling_outputs import Transformer2DModelOutput @@ -332,6 +333,59 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value + def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): + # IP-Adapter cross attention parameters + hidden_size = self.config.attention_head_dim * self.config.num_attention_heads + ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads + timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] + + # Dict where key is transformer layer index, value is attention processor's state dict + # ip_adapter state dict keys example: "0.norm_ip.linear.weight" + layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} + for key, weights in state_dict["ip_adapter"].items(): + idx, name = key.split(".", maxsplit=1) + layer_state_dict[int(idx)][name] = weights + + # Create IP-Adapter attention processor + attn_procs = {} + for idx, name in enumerate(self.attn_processors.keys()): + attn_procs[name] = IPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, + ip_hidden_states_dim=ip_hidden_states_dim, + head_dim=self.config.attention_head_dim, + timesteps_emb_dim=timesteps_emb_dim, + ).to(self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) + else: + load_model_dict_into_meta(attn_procs[name], layer_state_dict, device=self.device, dtype=self.dtype) + + self.set_attn_processor(attn_procs) + + # Image projetion parameters + embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dim = state_dict["image_proj"]["latents"].shape[2] + heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 + num_queries = state_dict["image_proj"]["latents"].shape[1] + timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] + + # Image projection + self.image_proj = TimePerceiverResampler( + embed_dim=embed_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + heads=heads, + num_queries=num_queries, + timestep_in_dim=timestep_in_dim + ).to(device=self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) + else: + load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) + def forward( self, hidden_states: torch.FloatTensor, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 443ce3787182..11d285e15051 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -16,7 +16,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch -from safetensors import safe_open from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, @@ -27,18 +26,10 @@ ) from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin, SD3IPAdapterMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel -from ...models.embeddings import TimePerceiverResampler from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...models.attention_processor import IPAdapterJointAttnProcessor2_0 -from ...models.modeling_utils import ( - load_model_dict_into_meta, - _LOW_CPU_MEM_USAGE_DEFAULT, - load_state_dict, -) -from huggingface_hub.utils import validate_hf_hub_args from ...utils import ( USE_PEFT_BACKEND, is_torch_xla_available, @@ -46,9 +37,6 @@ replace_example_docstring, scale_lora_layers, unscale_lora_layers, - is_torch_version, - is_accelerate_available, - _get_model_file, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -142,7 +130,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): r""" Args: transformer ([`SD3Transformer2DModel`]): @@ -691,174 +679,6 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - @validate_hf_hub_args - def load_ip_adapter( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - weight_name: str, - subfolder: Optional[str] = None, - **kwargs, - ): - """ - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `dict`): - Can be either: - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - weight_name (`str`): - The name of the weight file to load. - subfolder (`str, *optional*): - The subfolder location of a model file within a larger model repository on the Hub or locally. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. - """ - - # Load the main state dict first - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) - - if low_cpu_mem_usage: - if is_accelerate_available(): - from accelerate import init_empty_weights - - else: - low_cpu_mem_usage = False - logger.warning( - "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" - " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" - " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" - " install accelerate\n```\n." - ) - - if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): - raise NotImplementedError( - "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" - " `low_cpu_mem_usage=False`." - ) - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - if not isinstance(pretrained_model_name_or_path_or_dict, dict): - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - if weight_name.endswith(".safetensors"): - state_dict = {"image_proj": {}, "ip_adapter": {}} - with safe_open(model_file, framework="pt", device="cpu") as f: - for key in f.keys(): - if key.startswith("image_proj."): - state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) - elif key.startswith("ip_adapter."): - state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) - else: - state_dict = load_state_dict(model_file) - else: - state_dict = pretrained_model_name_or_path_or_dict - - if list(state_dict.keys()) != ["image_proj", "ip_adapter"]: - raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") - - # Load ip_adapter - hidden_size = self.transformer.config.attention_head_dim * self.transformer.config.num_attention_heads - ip_hidden_states_dim = self.transformer.config.attention_head_dim * self.transformer.config.num_attention_heads - timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] - - # State dict by layer - layer_state_dict = {idx: {} for idx in range(len(self.transformer.attn_processors))} - for key, weights in state_dict["ip_adapter"].items(): - idx, name = key.split(".", 1) - layer_state_dict[int(idx)][name] = weights - - attn_procs = {} - for idx, name in enumerate(self.transformer.attn_processors.keys()): - attn_procs[name] = IPAdapterJointAttnProcessor2_0( - hidden_size=hidden_size, - ip_hidden_states_dim=ip_hidden_states_dim, - head_dim=self.transformer.config.attention_head_dim, - timesteps_emb_dim=timesteps_emb_dim, - ).to(self.device, dtype=self.dtype) - - if not low_cpu_mem_usage: - attn_procs[name].load_state_dict(layer_state_dict[idx]) - else: - load_model_dict_into_meta(attn_procs[name], layer_state_dict, device=self.device, dtype=self.dtype) - - self.transformer.set_attn_processor(attn_procs) - - # Load image_proj - embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] - output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] - hidden_dim = state_dict["image_proj"]["latents"].shape[2] - heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 - num_queries = state_dict["image_proj"]["latents"].shape[1] - timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] - - self.image_proj = TimePerceiverResampler( - embed_dim=embed_dim, - output_dim=output_dim, - hidden_dim=hidden_dim, - heads=heads, - num_queries=num_queries, - timestep_in_dim=timestep_in_dim - ).to(device=self.device, dtype=self.dtype) - - if not low_cpu_mem_usage: - self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) - else: - load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) - - def set_ip_adapter_scale(self, scale): - """ - Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image prompt, and 0.0 - only conditioned by the text prompt. Lowering this value encourages the model to produce more diverse images, but they - may not be as aligned with the image prompt. - """ - for attn_processor in self.transformer.attn_processors.values(): - if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0): - attn_processor.scale = scale - # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image def encode_image(self, image): if not isinstance(image, torch.Tensor): @@ -1173,7 +993,7 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: - ip_hidden_states, temb = self.image_proj( + ip_hidden_states, temb = self.transformer.image_proj( image_embeds, timestep.to(dtype=latents.dtype), need_temb=True, From ab0d90421917ef708c7a898d0c764b6ac68a0c38 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 10:54:46 +0000 Subject: [PATCH 05/44] ip_adapter image embeds now considers num_images_per_prompt --- .../pipeline_stable_diffusion_3.py | 50 ++++++------------- 1 file changed, 15 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 11d285e15051..69416df3469a 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -695,42 +695,22 @@ def encode_image(self, image): def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): - # image_embeds = [] - - # if do_classifier_free_guidance: - # negative_image_embeds = [] - - # if ip_adapter_image_embeds is None: - # single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image) - # image_embeds.append(single_image_embeds[None, :]) - - # if do_classifier_free_guidance: - # negative_image_embeds.append(single_negative_image_embeds[None, :]) - # else: - # for single_image_embeds in ip_adapter_image_embeds: - # if do_classifier_free_guidance: - # single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - # negative_image_embeds.append(single_negative_image_embeds) - # image_embeds.append(single_image_embeds) - - # ip_adapter_image_embeds = [] - # for i, single_image_embeds in enumerate(image_embeds): - # single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - - # if do_classifier_free_guidance: - # single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) - # single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) - - # single_image_embeds = single_image_embeds.to(device=device) - # ip_adapter_image_embeds.append(single_image_embeds) - - - # Single image only :/ - clip_image_tensor = self.feature_extractor(images=ip_adapter_image, return_tensors="pt").pixel_values - clip_image_tensor = clip_image_tensor.to(device, dtype=self.dtype) - clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2] + if ip_adapter_image_embeds is None: + single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - return torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + return single_image_embeds.to(device=device) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) From 5aed1d3d0a7b8aac81976b0bd286394c4211133b Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 14:02:02 +0000 Subject: [PATCH 06/44] Removed usage of einops --- src/diffusers/models/attention_processor.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b77682fb0bf0..122909698b29 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,7 +18,6 @@ import torch import torch.nn.functional as F from torch import nn -from einops import rearrange from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, is_torch_xla_available, logging @@ -5156,11 +5155,11 @@ def __call__( ip_value = self.to_v_ip(norm_ip_hidden_states) # Reshape - img_query = rearrange(img_query, 'b l (h d) -> b h l d', h=attn.heads) - img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads) - img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads) - ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads) - ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads) + img_query = img_query.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) + img_key = img_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) + img_value = img_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) + ip_key = ip_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) + ip_value = ip_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) # Norm img_query = self.norm_q(img_query) @@ -5172,7 +5171,7 @@ def __call__( img_value = torch.cat([img_value, ip_value], dim=2) ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False) - ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)') + ip_hidden_states = ip_hidden_states.transpose(1,2).view(batch_size, head_dim, -1) ip_hidden_states = ip_hidden_states.to(img_query.dtype) hidden_states = hidden_states + ip_hidden_states * self.scale From 4383175d04b88750ec266c581b5cb1153df7a05e Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 14:26:20 +0000 Subject: [PATCH 07/44] Reverted joint_attention_kwargs default for consistency --- .../pipeline_stable_diffusion_3.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 69416df3469a..fb457fc50429 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -738,7 +738,7 @@ def __call__( ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - joint_attention_kwargs: Dict[str, Any] = {}, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], @@ -980,11 +980,14 @@ def __call__( ) image_prompt_embeds = dict( - ip_hidden_states = ip_hidden_states, - temb = temb + ip_hidden_states=ip_hidden_states, + temb=temb ) - else: - image_prompt_embeds = {} + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = image_prompt_embeds + else: + self._joint_attention_kwargs.update(**image_prompt_embeds) noise_pred = self.transformer( hidden_states=latent_model_input, @@ -992,10 +995,7 @@ def __call__( encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, return_dict=False, - joint_attention_kwargs={ - **image_prompt_embeds, - **self.joint_attention_kwargs, - } + joint_attention_kwargs=self.joint_attention_kwargs, )[0] # perform guidance @@ -1016,10 +1016,7 @@ def __call__( timestep=timestep, encoder_hidden_states=original_prompt_embeds, pooled_projections=original_pooled_prompt_embeds, - joint_attention_kwargs={ - **image_prompt_embeds, - **self.joint_attention_kwargs, - }, + joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, skip_layers=skip_guidance_layers, )[0] From 461ab73375e868e9dff3100bea5e58b63b55d4c2 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 15:12:51 +0000 Subject: [PATCH 08/44] Corrected einops removal --- src/diffusers/models/attention_processor.py | 12 ++++++------ .../pipeline_stable_diffusion_3.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 122909698b29..22088c9f9101 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5155,11 +5155,11 @@ def __call__( ip_value = self.to_v_ip(norm_ip_hidden_states) # Reshape - img_query = img_query.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) - img_key = img_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) - img_value = img_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) - ip_key = ip_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) - ip_value = ip_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2) + img_query = img_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + img_key = img_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + img_value = img_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # Norm img_query = self.norm_q(img_query) @@ -5171,7 +5171,7 @@ def __call__( img_value = torch.cat([img_value, ip_value], dim=2) ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False) - ip_hidden_states = ip_hidden_states.transpose(1,2).view(batch_size, head_dim, -1) + ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(img_query.dtype) hidden_states = hidden_states + ip_hidden_states * self.scale diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index fb457fc50429..2bee2fe87914 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -994,8 +994,8 @@ def __call__( timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, - return_dict=False, joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, )[0] # perform guidance From 832324085160bea53bf71250c7cefcbda704774a Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 15:37:42 +0000 Subject: [PATCH 09/44] Quality and style checks --- src/diffusers/loaders/ip_adapter.py | 32 ++++++++---------- src/diffusers/models/attention.py | 12 ++++--- src/diffusers/models/attention_processor.py | 15 +++++---- src/diffusers/models/embeddings.py | 20 +++++------ .../models/transformers/transformer_sd3.py | 8 +++-- .../pipeline_stable_diffusion_3.py | 33 +++++++++---------- 6 files changed, 59 insertions(+), 61 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 0960d90f018f..34f3a151262f 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -33,23 +33,19 @@ if is_transformers_available(): - from transformers import ( - CLIPImageProcessor, - CLIPVisionModelWithProjection, - SiglipImageProcessor, - SiglipVisionModel - ) + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel from ..models.attention_processor import ( AttnProcessor, AttnProcessor2_0, - JointAttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, - IPAdapterXFormersAttnProcessor, IPAdapterJointAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, + JointAttnProcessor2_0, ) + logger = logging.get_logger(__name__) @@ -495,8 +491,10 @@ def load_ip_adapter( ) self.register_modules( - feature_extractor = SiglipImageProcessor.from_pretrained(**args).to(self.device, dtype=self.dtype), - image_encoder = SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype), + feature_extractor=SiglipImageProcessor.from_pretrained(**args).to( + self.device, dtype=self.dtype + ), + image_encoder=SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype), ) else: raise ValueError( @@ -513,9 +511,9 @@ def load_ip_adapter( def set_ip_adapter_scale(self, scale: float): """ - Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image prompt, and 0.0 - only conditioned by the text prompt. Lowering this value encourages the model to produce more diverse images, but they - may not be as aligned with the image prompt. + Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image + prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages the model to produce more + diverse images, but they may not be as aligned with the image prompt. Example: @@ -556,11 +554,7 @@ def unload_ip_adapter(self): # Restore original attention processors layers attn_procs = { - name: ( - JointAttnProcessor2_0() - if isinstance(value, IPAdapterJointAttnProcessor2_0) - else value.__class__() - ) + name: (JointAttnProcessor2_0() if isinstance(value, IPAdapterJointAttnProcessor2_0) else value.__class__()) for name, value in self.transformer.attn_processors.items() } - self.transformer.set_attn_processor(attn_procs) \ No newline at end of file + self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index bb761d1a99c2..e8a47fa8226a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -188,8 +188,11 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): self._chunk_dim = dim def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, - joint_attention_kwargs: Dict[str, Any] = {} + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + joint_attention_kwargs: Dict[str, Any] = {}, ): if self.use_dual_attention: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( @@ -207,8 +210,9 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, - **joint_attention_kwargs + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + **joint_attention_kwargs, ) # Process attention outputs for the `hidden_states`. diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 22088c9f9101..a8fb8bf28f89 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5047,7 +5047,7 @@ def __call__( hidden_states = hidden_states / attn.rescale_output_factor return hidden_states - + class IPAdapterJointAttnProcessor2_0(torch.nn.Module): """Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections.""" @@ -5058,15 +5058,14 @@ def __init__( ip_hidden_states_dim: int, head_dim: int, timesteps_emb_dim: int = 1280, - scale: float = 0.5 + scale: float = 0.5, ): super().__init__() # To prevent circular import - from .normalization import RMSNorm, AdaLayerNorm + from .normalization import AdaLayerNorm, RMSNorm - self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, - norm_eps=1e-6, chunk_dim=1) + self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1) self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) self.norm_q = RMSNorm(head_dim, 1e-6) @@ -5081,7 +5080,7 @@ def __call__( encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, ip_hidden_states: torch.FloatTensor = None, - temb: torch.FloatTensor = None + temb: torch.FloatTensor = None, ) -> torch.FloatTensor: residual = hidden_states @@ -5170,7 +5169,9 @@ def __call__( img_key = torch.cat([img_key, ip_key], dim=2) img_value = torch.cat([img_value, ip_value], dim=2) - ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False) + ip_hidden_states = F.scaled_dot_product_attention( + img_query, img_key, img_value, dropout_p=0.0, is_causal=False + ) ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(img_query.dtype) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0d380e240ea6..8d69eb8ab72f 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2115,7 +2115,7 @@ def __init__( ) -> None: super().__init__() - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.dim_head = dim_head self.heads = heads inner_dim = dim_head * heads @@ -2135,6 +2135,7 @@ def forward(self, x, latents, shift=None, scale=None): latent (torch.Tensor): latent features shape (b, n2, D) """ + def reshape_tensor(x, heads): bs, length, _ = x.shape # (bs, length, width) --> (bs, length, n_heads, dim_per_head) @@ -2169,7 +2170,7 @@ def reshape_tensor(x, heads): out = out.permute(0, 2, 1, 3).reshape(b, l, -1) return self.to_out(out) - + # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py class TimePerceiverResampler(nn.Module): @@ -2188,12 +2189,12 @@ def __init__( timestep_freq_shift: int = 0, ) -> None: super().__init__() - - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim ** 0.5) + + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5) self.proj_in = nn.Linear(embed_dim, hidden_dim) self.proj_out = nn.Linear(hidden_dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) - + ff_inner_dim = int(hidden_dim * ffn_ratio) self.layers = nn.ModuleList([]) for _ in range(depth): @@ -2210,10 +2211,7 @@ def __init__( nn.Linear(ff_inner_dim, hidden_dim, bias=False), ), # adaLN - nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_dim, ff_inner_dim, bias=True) - ) + nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, ff_inner_dim, bias=True)), ] ) ) @@ -2227,7 +2225,7 @@ def forward(self, x, timestep, need_temb=False): timestep_emb = self.time_embedding(timestep_emb, None) latents = self.latents.repeat(x.size(0), 1, 1) - + x = self.proj_in(x) x = x + timestep_emb[:, None] @@ -2242,7 +2240,7 @@ def forward(self, x, timestep, need_temb=False): if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) latents = latents + res - + latents = self.proj_out(latents) latents = self.norm_out(latents) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 0e46a86abf22..fdc318fe9cc7 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -24,8 +24,8 @@ Attention, AttentionProcessor, FusedJointAttnProcessor2_0, - JointAttnProcessor2_0, IPAdapterJointAttnProcessor2_0, + JointAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin, load_model_dict_into_meta from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero @@ -376,7 +376,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): hidden_dim=hidden_dim, heads=heads, num_queries=num_queries, - timestep_in_dim=timestep_in_dim + timestep_in_dim=timestep_in_dim, ).to(device=self.device, dtype=self.dtype) if not low_cpu_mem_usage: @@ -470,7 +470,9 @@ def custom_forward(*inputs): ) elif not is_skip: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, joint_attention_kwargs=joint_attention_kwargs, ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 2bee2fe87914..0f5c457ca7e6 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -17,16 +17,16 @@ import torch from transformers import ( + BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + PreTrainedModel, T5EncoderModel, T5TokenizerFast, - PreTrainedModel, - BaseImageProcessor, ) -from ...image_processor import VaeImageProcessor, PipelineImageInput -from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin, SD3IPAdapterMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -184,7 +184,7 @@ def __init__( text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None + feature_extractor: BaseImageProcessor = None, ): super().__init__() @@ -199,7 +199,7 @@ def __init__( transformer=transformer, scheduler=scheduler, image_encoder=image_encoder, - feature_extractor=feature_extractor + feature_extractor=feature_extractor, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -678,7 +678,7 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt - + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image def encode_image(self, image): if not isinstance(image, torch.Tensor): @@ -687,8 +687,10 @@ def encode_image(self, image): image = image.to(device=self.device, dtype=self.dtype) image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - uncond_image_enc_hidden_states = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2] - + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + return image_enc_hidden_states, uncond_image_enc_hidden_states # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds @@ -696,7 +698,7 @@ def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): if ip_adapter_image_embeds is None: - single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image) + single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: @@ -705,13 +707,13 @@ def prepare_ip_adapter_image_embeds( single_image_embeds = ip_adapter_image_embeds single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) - + if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) return single_image_embeds.to(device=device) - + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -979,15 +981,12 @@ def __call__( need_temb=True, ) - image_prompt_embeds = dict( - ip_hidden_states=ip_hidden_states, - temb=temb - ) + image_prompt_embeds = {"ip_hidden_states": ip_hidden_states, "temb": temb} if self.joint_attention_kwargs is None: self._joint_attention_kwargs = image_prompt_embeds else: - self._joint_attention_kwargs.update(**image_prompt_embeds) + self._joint_attention_kwargs.update(**image_prompt_embeds) noise_pred = self.transformer( hidden_states=latent_model_input, From 89c4e6343c431b15b546e748914638256ac74704 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 15:42:56 +0000 Subject: [PATCH 10/44] Quality and style checks --- .../pipeline_stable_diffusion_3_ipa.py | 126 ++++++++---------- 1 file changed, 53 insertions(+), 73 deletions(-) diff --git a/examples/community/pipeline_stable_diffusion_3_ipa.py b/examples/community/pipeline_stable_diffusion_3_ipa.py index 490f65b7a719..d7830178e67f 100644 --- a/examples/community/pipeline_stable_diffusion_3_ipa.py +++ b/examples/community/pipeline_stable_diffusion_3_ipa.py @@ -13,11 +13,13 @@ # limitations under the License. import inspect +import math from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F +from einops import rearrange from transformers import ( CLIPTextModelWithProjection, CLIPTokenizer, @@ -28,6 +30,11 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.normalization import RMSNorm +from diffusers.models.transformers import SD3Transformer2DModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, @@ -38,15 +45,6 @@ unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput - -from diffusers.models.transformers import SD3Transformer2DModel -from diffusers.models.normalization import RMSNorm -from einops import rearrange -import math - -from diffusers.models.embeddings import Timesteps, TimestepEmbedding if is_torch_xla_available(): @@ -86,10 +84,10 @@ def FeedForward(dim, mult=4): nn.Linear(inner_dim, dim, bias=False), ) - + def reshape_tensor(x, heads): bs, length, width = x.shape - #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, length, heads, -1) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) x = x.transpose(1, 2) @@ -113,7 +111,6 @@ def __init__(self, *, dim, dim_head=64, heads=8): self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False) - def forward(self, x, latents, shift=None, scale=None): """ Args: @@ -127,23 +124,23 @@ def forward(self, x, latents, shift=None, scale=None): if shift is not None and scale is not None: latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - + b, l, _ = latents.shape q = self.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) k, v = self.to_kv(kv_input).chunk(2, dim=-1) - + q = reshape_tensor(q, self.heads) k = reshape_tensor(k, self.heads) v = reshape_tensor(v, self.heads) # attention scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) out = weight @ v - + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) return self.to_out(out) @@ -166,14 +163,14 @@ def __init__( timestep_freq_shift=0, ): super().__init__() - + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) - + self.proj_in = nn.Linear(embedding_dim, dim) self.proj_out = nn.Linear(dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) - + self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( @@ -184,7 +181,7 @@ def __init__( # ff FeedForward(dim=dim, mult=ff_mult), # adaLN - nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)) + nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)), ] ) ) @@ -199,12 +196,11 @@ def __init__( # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True) # ) - def forward(self, x, timestep, need_temb=False): timestep_emb = self.embedding_time(x, timestep) # bs, dim latents = self.latents.repeat(x.size(0), 1, 1) - + x = self.proj_in(x) x = x + timestep_emb[:, None] @@ -221,7 +217,7 @@ def forward(self, x, timestep, need_temb=False): latents = latents + res # latents = ff(latents) + latents - + latents = self.proj_out(latents) latents = self.norm_out(latents) @@ -230,10 +226,7 @@ def forward(self, x, timestep, need_temb=False): else: return latents - - def embedding_time(self, sample, timestep): - # 1. time timesteps = timestep if not torch.is_tensor(timesteps): @@ -271,15 +264,12 @@ class AdaLayerNorm(nn.Module): num_embeddings (`int`): The size of the embeddings dictionary. """ - def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'): + def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"): super().__init__() self.silu = nn.SiLU() - num_params_dict = dict( - zero=6, - normal=2, - ) - num_params = num_params_dict[mode] + + num_params = 2 if mode == "normal" else 6 self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True) self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) self.mode = mode @@ -287,16 +277,16 @@ def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'): def forward( self, x, - hidden_dtype = None, - emb = None, + hidden_dtype=None, + emb=None, ): emb = self.linear(self.silu(emb)) - if self.mode == 'normal': + if self.mode == "normal": shift_msa, scale_msa = emb.chunk(2, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x - elif self.mode == 'zero': + elif self.mode == "zero": shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp @@ -323,7 +313,6 @@ def __init__( self.norm_k = RMSNorm(head_dim, 1e-6) self.norm_ip_k = RMSNorm(head_dim, 1e-6) - def __call__( self, attn, @@ -396,9 +385,8 @@ def __call__( if not attn.context_pre_only: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - # IPadapter - ip_hidden_states = emb_dict.get('ip_hidden_states', None) + ip_hidden_states = emb_dict.get("ip_hidden_states", None) ip_hidden_states = self.get_ip_hidden_states( attn, img_query, @@ -407,11 +395,10 @@ def __call__( img_value, None, None, - emb_dict['temb'], + emb_dict["temb"], ) if ip_hidden_states is not None: - hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0) - + hidden_states = hidden_states + ip_hidden_states * emb_dict.get("scale", 1.0) # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -423,12 +410,13 @@ def __call__( else: return hidden_states - - def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None): + def get_ip_hidden_states( + self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None + ): if ip_hidden_states is None: return None - - if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'): + + if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"): return None # norm ip input @@ -439,11 +427,11 @@ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_ ip_value = self.to_v_ip(norm_ip_hidden_states) # reshape - query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads) - img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads) - img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads) - ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads) - ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads) + query = rearrange(query, "b l (h d) -> b h l d", h=attn.heads) + img_key = rearrange(img_key, "b l (h d) -> b h l d", h=attn.heads) + img_value = rearrange(img_value, "b l (h d) -> b h l d", h=attn.heads) + ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=attn.heads) + ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=attn.heads) # norm query = self.norm_q(query) @@ -454,9 +442,9 @@ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_ key = torch.cat([img_key, ip_key], dim=2) value = torch.cat([img_value, ip_value], dim=2) - # + # ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)') + ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)") ip_hidden_states = ip_hidden_states.to(query.dtype) return ip_hidden_states @@ -1049,10 +1037,10 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - @torch.inference_mode() def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432): - from transformers import SiglipVisionModel, SiglipImageProcessor + from transformers import SiglipImageProcessor, SiglipVisionModel + state_dict = torch.load(ip_adapter_path, map_location="cpu") device, dtype = self.transformer.device, self.transformer.dtype @@ -1084,14 +1072,13 @@ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_d self.image_proj_model = image_proj_model - attn_procs = {} transformer = self.transformer for idx_name, name in enumerate(transformer.attn_processors.keys()): hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim - + attn_procs[name] = JointIPAttnProcessor( hidden_size=hidden_size, cross_attention_dim=transformer.config.caption_projection_dim, @@ -1107,10 +1094,8 @@ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_d key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) print(f"=> loading ip_adapter: {key_name}") - @torch.inference_mode() def encode_clip_image_emb(self, clip_image, device, dtype): - # clip clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values clip_image_tensor = clip_image_tensor.to(device, dtype=dtype) @@ -1119,8 +1104,6 @@ def encode_clip_image_emb(self, clip_image, device, dtype): return clip_image_embeds - - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1150,7 +1133,6 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, - # ipa clip_image=None, ipadapter_scale=1.0, @@ -1349,18 +1331,16 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) image_prompt_embeds, timestep_emb = self.image_proj_model( - clip_image_embeds, - timestep.to(dtype=latents.dtype), - need_temb=True + clip_image_embeds, timestep.to(dtype=latents.dtype), need_temb=True ) - joint_attention_kwargs = dict( - emb_dict=dict( - ip_hidden_states=image_prompt_embeds, - temb=timestep_emb, - scale=ipadapter_scale, - ) - ) + joint_attention_kwargs = { + "emb_dict": { + "ip_hidden_states": image_prompt_embeds, + "temb": timestep_emb, + "scale": ipadapter_scale, + } + } noise_pred = self.transformer( hidden_states=latent_model_input, @@ -1420,4 +1400,4 @@ def __call__( if not return_dict: return (image,) - return StableDiffusion3PipelineOutput(images=image) \ No newline at end of file + return StableDiffusion3PipelineOutput(images=image) From 27d574fe8538f7dc3935943885fa4a5824da7308 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 9 Dec 2024 16:33:12 +0000 Subject: [PATCH 11/44] Handle None joint_attention_kwargs in JointTransformerBlock --- src/diffusers/models/attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e8a47fa8226a..d478917412f3 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -192,7 +192,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, - joint_attention_kwargs: Dict[str, Any] = {}, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): if self.use_dual_attention: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( @@ -208,6 +208,10 @@ def forward( encoder_hidden_states, emb=temb ) + # Empty dict if None is passed + if joint_attention_kwargs is None: + joint_attention_kwargs = {} + # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, From 0a4864838ce4492f1d87a6439f12274ee1d48ef6 Mon Sep 17 00:00:00 2001 From: hlky <hlky@hlky.ac> Date: Mon, 9 Dec 2024 21:20:23 +0000 Subject: [PATCH 12/44] Fix test_components_function --- .../stable_diffusion_3/test_pipeline_stable_diffusion_3.py | 2 ++ .../test_pipeline_stable_diffusion_3_img2img.py | 2 ++ .../test_pipeline_stable_diffusion_3_inpaint.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 07ce5487f256..a6f718ae4fbb 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -103,6 +103,8 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py index 695954163c8f..358c8d9aee12 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py @@ -105,6 +105,8 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py index 464ef6d017df..a37ea3fc39c5 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py @@ -106,6 +106,8 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): From 10d0a0623c25bee16fca7bd7b58163e76bc2ad82 Mon Sep 17 00:00:00 2001 From: hlky <hlky@hlky.ac> Date: Mon, 9 Dec 2024 21:35:39 +0000 Subject: [PATCH 13/44] Remove from img2img/inpaint for now --- .../test_pipeline_stable_diffusion_3_img2img.py | 2 -- .../test_pipeline_stable_diffusion_3_inpaint.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py index 358c8d9aee12..695954163c8f 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py @@ -105,8 +105,6 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, - "image_encoder": None, - "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py index a37ea3fc39c5..464ef6d017df 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py @@ -106,8 +106,6 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, - "image_encoder": None, - "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0): From c78c4fd12a474ecdd2295f20f9ac0e8d8fda0f73 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 02:23:56 +0000 Subject: [PATCH 14/44] Fixed loading ip_adapter state dict --- src/diffusers/models/transformers/transformer_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index fdc318fe9cc7..32fbb2a6c103 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -357,7 +357,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): if not low_cpu_mem_usage: attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) else: - load_model_dict_into_meta(attn_procs[name], layer_state_dict, device=self.device, dtype=self.dtype) + load_model_dict_into_meta(attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype) self.set_attn_processor(attn_procs) From 0f6c6079df30e9469cd8b606c6b6c833a2eba986 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 11:58:24 +0000 Subject: [PATCH 15/44] Simpler image encoding --- .../pipeline_stable_diffusion_3.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 0f5c457ca7e6..2b0c82ee700b 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -686,33 +686,29 @@ def encode_image(self, image): image = image.to(device=self.device, dtype=self.dtype) - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - - return image_enc_hidden_states, uncond_image_enc_hidden_states + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): if ip_adapter_image_embeds is None: - single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image) + single_image_embeds = self.encode_image(ip_adapter_image) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) else: - for single_image_embeds in ip_adapter_image_embeds: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - else: - single_image_embeds = ip_adapter_image_embeds + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds - single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: - single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) - single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) - return single_image_embeds.to(device=device) + return image_embeds.to(device=device) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) From 53fd40db8021e0899be21fb5b2f3a49a9cf90a7f Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 11:58:49 +0000 Subject: [PATCH 16/44] Style check --- src/diffusers/models/transformers/transformer_sd3.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 32fbb2a6c103..f8ee6e77b2c3 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -357,7 +357,9 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): if not low_cpu_mem_usage: attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) else: - load_model_dict_into_meta(attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype) + load_model_dict_into_meta( + attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype + ) self.set_attn_processor(attn_procs) From 8039599fa7b06cb03744c8b1ec0ac45f9ac31f6a Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 12:30:48 +0000 Subject: [PATCH 17/44] Better checks for image prompt considering ip_adapter scale --- src/diffusers/loaders/ip_adapter.py | 18 ++++++++++++++++++ .../pipeline_stable_diffusion_3.py | 6 ++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 34f3a151262f..d2aca385c49c 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -356,6 +356,24 @@ def unload_ip_adapter(self): class SD3IPAdapterMixin: """Mixin for handling StableDiffusion 3 IP Adapters.""" + @property + def is_ip_adapter_active(self) -> bool: + r"""Checks if any ip_adapter attention processor have scale > 0. + + IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0, + image is irrelevant. + + Returns: + `bool`: True when ip_adapter is loaded and any ip_adapter layer scale > 0. + """ + scales = [ + attn_proc.scale + for attn_proc in self.transformer.attn_processors.values() + if isinstance(attn_proc, IPAdapterJointAttnProcessor2_0) + ] + + return len(scales) > 0 and any(scale > 0 for scale in scales) + @validate_hf_hub_args def load_ip_adapter( self, diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 2b0c82ee700b..aa9eb15c2bdd 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -950,7 +950,9 @@ def __call__( ) # 6. Prepare image embeddings - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + # Either image is passed and ip_adapter is active + # Or image_embeds are passed directly + if (ip_adapter_image is not None and self.is_ip_adapter_active()) or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, @@ -970,7 +972,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + if image_embeds is not None: ip_hidden_states, temb = self.transformer.image_proj( image_embeds, timestep.to(dtype=latents.dtype), From 7333bfc352e1c2123577d82aebc93025db50021d Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 12:36:33 +0000 Subject: [PATCH 18/44] Minor change correcting checking for ip_adapter embeds --- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index aa9eb15c2bdd..c0738292bdaa 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -953,7 +953,7 @@ def __call__( # Either image is passed and ip_adapter is active # Or image_embeds are passed directly if (ip_adapter_image is not None and self.is_ip_adapter_active()) or ip_adapter_image_embeds is not None: - image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, @@ -972,9 +972,9 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - if image_embeds is not None: + if ip_adapter_image_embeds is not None: ip_hidden_states, temb = self.transformer.image_proj( - image_embeds, + ip_adapter_image_embeds, timestep.to(dtype=latents.dtype), need_temb=True, ) From a87895e94ba97545b09eae35107e9a3103d73fa1 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 12:44:46 +0000 Subject: [PATCH 19/44] Removing old check of ip_adapter scale --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a8fb8bf28f89..6372dfb79271 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5145,7 +5145,7 @@ def __call__( encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # IP Adapter - if self.scale != 0 and ip_hidden_states is not None: + if ip_hidden_states is not None: # Norm image features norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb) From 4ba374a688218716ec5b6bddfd98fbecb1e45496 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 17:59:05 +0000 Subject: [PATCH 20/44] Refactor of image_proj (testing) --- src/diffusers/models/attention_processor.py | 3 +- src/diffusers/models/embeddings.py | 142 ++++++------------ .../models/transformers/transformer_sd3.py | 35 +++-- 3 files changed, 70 insertions(+), 110 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6372dfb79271..4eed41e2a5ba 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3927,9 +3927,8 @@ def __call__( key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 8d69eb8ab72f..deb76e6040aa 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -21,7 +21,7 @@ from ..utils import deprecate from .activations import FP32SiLU, get_activation -from .attention_processor import Attention +from .attention_processor import Attention, FusedAttnProcessor2_0 def get_timestep_embedding( @@ -2104,76 +2104,55 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: return out -# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py -class TimePerceiverAttention(nn.Module): +class IPAdapterTimeImageProjectionBlock(nn.Module): def __init__( self, - *, - dim: int, + hidden_dim: int = 768, dim_head: int = 64, - heads: int = 8, + heads: int = 16, + ffn_ratio: float = 4, ) -> None: super().__init__() + from .attention import FeedForward - self.scale = dim_head**-0.5 - self.dim_head = dim_head - self.heads = heads - inner_dim = dim_head * heads - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - def forward(self, x, latents, shift=None, scale=None): - """ - Args: - x (torch.Tensor): image features - shape (b, n1, D) - latent (torch.Tensor): latent features - shape (b, n2, D) - """ - - def reshape_tensor(x, heads): - bs, length, _ = x.shape - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, length, heads, -1) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - return x.reshape(bs, heads, length, -1) - - x = self.norm1(x) - latents = self.norm2(latents) - - if shift is not None and scale is not None: - latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - b, l, _ = latents.shape + self.ln0 = nn.LayerNorm(hidden_dim) + self.ln1 = nn.LayerNorm(hidden_dim) + self.attn = Attention( + query_dim=hidden_dim, + cross_attention_dim=hidden_dim, + dim_head=dim_head, + heads=heads, + bias=False, + out_bias=False, + processor=FusedAttnProcessor2_0(), + ) + self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False) - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) + # AdaLayerNorm + self.adaln_silu = nn.SiLU() + self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim) + self.adaln_norm = nn.LayerNorm(hidden_dim) - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) + # Custom scale cannot be passed in constructor + self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head)) + self.attn.fuse_projections() + self.attn.to_k = None + self.attn.to_v = None - # attention - scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v + def forward(self, x, latents, timestep_emb): + shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaln_proj(self.adaln_silu(timestep_emb)) - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + x = self.ln0(x) + latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None] + latents = self.attn(x, latents) + latents - return self.to_out(out) + residual = latents + latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + return self.ff(latents) + residual # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py -class TimePerceiverResampler(nn.Module): +class IPAdapterTimeImageProjection(nn.Module): def __init__( self, embed_dim: int = 1152, @@ -2189,65 +2168,32 @@ def __init__( timestep_freq_shift: int = 0, ) -> None: super().__init__() - self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5) self.proj_in = nn.Linear(embed_dim, hidden_dim) self.proj_out = nn.Linear(hidden_dim, output_dim) self.norm_out = nn.LayerNorm(output_dim) - - ff_inner_dim = int(hidden_dim * ffn_ratio) - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - # msa - TimePerceiverAttention(dim=hidden_dim, dim_head=dim_head, heads=heads), - # ff - nn.Sequential( - nn.LayerNorm(hidden_dim), - nn.Linear(hidden_dim, ff_inner_dim, bias=False), - nn.GELU(), - nn.Linear(ff_inner_dim, hidden_dim, bias=False), - ), - # adaLN - nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, ff_inner_dim, bias=True)), - ] - ) - ) - - # Time + self.layers = nn.ModuleList( + [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu") - def forward(self, x, timestep, need_temb=False): + def forward(self, x, timestep): timestep_emb = self.time_proj(timestep).to(dtype=x.dtype) - timestep_emb = self.time_embedding(timestep_emb, None) + timestep_emb = self.time_embedding(timestep_emb) latents = self.latents.repeat(x.size(0), 1, 1) x = self.proj_in(x) x = x + timestep_emb[:, None] - for attn, ff, adaLN_modulation in self.layers: - shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1) - latents = attn(x, latents, shift_msa, scale_msa) + latents - - res = latents - for idx_ff in range(len(ff)): - layer_ff = ff[idx_ff] - latents = layer_ff(latents) - if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN - latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) - latents = latents + res + for block in self.layers: + latents = block(x, latents, timestep_emb) latents = self.proj_out(latents) latents = self.norm_out(latents) - if need_temb: - return latents, timestep_emb - else: - return latents + return latents, timestep_emb class MultiIPAdapterImageProjection(nn.Module): diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index f8ee6e77b2c3..b7b5397bcc8b 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -31,7 +31,7 @@ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed, TimePerceiverResampler +from ..embeddings import CombinedTimestepTextProjEmbeddings, IPAdapterTimeImageProjection, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -363,16 +363,31 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): self.set_attn_processor(attn_procs) + # Convert image_proj state dict to diffusers + image_proj_state_dict = {} + for key, value in state_dict["image_proj"].items(): + for idx in range(4): + key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0") + key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1") + key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q") + key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv") + key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0") + key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm") + key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj") + key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2") + key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj") + image_proj_state_dict[key] = value + # Image projetion parameters - embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] - output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] - hidden_dim = state_dict["image_proj"]["latents"].shape[2] - heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 - num_queries = state_dict["image_proj"]["latents"].shape[1] - timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] + embed_dim = image_proj_state_dict["proj_in.weight"].shape[1] + output_dim = image_proj_state_dict["proj_out.weight"].shape[0] + hidden_dim = image_proj_state_dict["proj_in.weight"].shape[0] + heads = image_proj_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 + num_queries = image_proj_state_dict["latents"].shape[1] + timestep_in_dim = image_proj_state_dict["time_embedding.linear_1.weight"].shape[1] # Image projection - self.image_proj = TimePerceiverResampler( + self.image_proj = IPAdapterTimeImageProjection( embed_dim=embed_dim, output_dim=output_dim, hidden_dim=hidden_dim, @@ -382,9 +397,9 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): ).to(device=self.device, dtype=self.dtype) if not low_cpu_mem_usage: - self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) + self.image_proj.load_state_dict(image_proj_state_dict, strict=True) else: - load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) + load_model_dict_into_meta(self.image_proj, image_proj_state_dict, device=self.device, dtype=self.dtype) def forward( self, From 819dd3e028383d51f56669b01b6ca2f5088548c6 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 18:03:09 +0000 Subject: [PATCH 21/44] Revert "Removing old check of ip_adapter scale" This reverts commit a87895e94ba97545b09eae35107e9a3103d73fa1. --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4eed41e2a5ba..c76909d8255f 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5144,7 +5144,7 @@ def __call__( encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # IP Adapter - if ip_hidden_states is not None: + if self.scale != 0 and ip_hidden_states is not None: # Norm image features norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb) From ea32e13227545d8e7ce2ff283324b0403953169c Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 10 Dec 2024 18:18:50 +0000 Subject: [PATCH 22/44] Corrected property check --- .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index c0738292bdaa..26063d58d116 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -952,7 +952,7 @@ def __call__( # 6. Prepare image embeddings # Either image is passed and ip_adapter is active # Or image_embeds are passed directly - if (ip_adapter_image is not None and self.is_ip_adapter_active()) or ip_adapter_image_embeds is not None: + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, From f60751f7831bd938fb5cae168eee89bfe28e8371 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 11 Dec 2024 10:41:51 +0000 Subject: [PATCH 23/44] Corrected forward() of IPAdapterTimeImageProjectionBlock --- src/diffusers/models/embeddings.py | 8 +++++--- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 4 +--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index deb76e6040aa..f67a4d2bd4cb 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2133,18 +2133,20 @@ def __init__( self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim) self.adaln_norm = nn.LayerNorm(hidden_dim) - # Custom scale cannot be passed in constructor + # Set scale and fuse KV self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head)) self.attn.fuse_projections() self.attn.to_k = None self.attn.to_v = None def forward(self, x, latents, timestep_emb): - shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaln_proj(self.adaln_silu(timestep_emb)) + emb = self.adaln_proj(self.adaln_silu(timestep_emb)) + shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1) + residual = latents x = self.ln0(x) latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None] - latents = self.attn(x, latents) + latents + latents = self.attn(latents, torch.cat((x, latents), dim=-2)) + residual residual = latents latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 26063d58d116..9efc433d75a5 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -974,9 +974,7 @@ def __call__( if ip_adapter_image_embeds is not None: ip_hidden_states, temb = self.transformer.image_proj( - ip_adapter_image_embeds, - timestep.to(dtype=latents.dtype), - need_temb=True, + ip_adapter_image_embeds, timestep.to(dtype=latents.dtype) ) image_prompt_embeds = {"ip_hidden_states": ip_hidden_states, "temb": temb} From b0aa5cb3fd9c48a646a08f61813abe6f3a90d907 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Thu, 12 Dec 2024 12:05:44 +0000 Subject: [PATCH 24/44] IPAdapterTimeImageProjectionBlock now uses original attention implementation --- src/diffusers/models/embeddings.py | 37 ++++++++++++++++--- .../models/transformers/transformer_sd3.py | 3 +- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f67a4d2bd4cb..2e0c773d8e04 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2107,10 +2107,10 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: class IPAdapterTimeImageProjectionBlock(nn.Module): def __init__( self, - hidden_dim: int = 768, + hidden_dim: int = 1280, dim_head: int = 64, - heads: int = 16, - ffn_ratio: float = 4, + heads: int = 20, + ffn_ratio: int = 4, ) -> None: super().__init__() from .attention import FeedForward @@ -2124,7 +2124,6 @@ def __init__( heads=heads, bias=False, out_bias=False, - processor=FusedAttnProcessor2_0(), ) self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False) @@ -2133,21 +2132,47 @@ def __init__( self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim) self.adaln_norm = nn.LayerNorm(hidden_dim) - # Set scale and fuse KV + # Set attention scale and fuse KV self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head)) self.attn.fuse_projections() self.attn.to_k = None self.attn.to_v = None def forward(self, x, latents, timestep_emb): + # Shift and scale for AdaLayerNorm emb = self.adaln_proj(self.adaln_silu(timestep_emb)) shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1) + # Fused Attention residual = latents x = self.ln0(x) latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None] - latents = self.attn(latents, torch.cat((x, latents), dim=-2)) + residual + batch_size = latents.shape[0] + + query = self.attn.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + kv = self.attn.to_kv(kv_input) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.attn.heads + + query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + + weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + latents = weight @ value + + latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim) + latents = self.attn.to_out[0](latents) + latents = self.attn.to_out[1](latents) + latents = latents + residual + + ## FeedForward residual = latents latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] return self.ff(latents) + residual diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index b7b5397bcc8b..fb46f2c03f6f 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -366,7 +366,8 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): # Convert image_proj state dict to diffusers image_proj_state_dict = {} for key, value in state_dict["image_proj"].items(): - for idx in range(4): + if key.startswith("layers."): + idx = key.split(".")[1] key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0") key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1") key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q") From b3dc69aef7d0769f61c50b6eafe12057b8e80287 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Thu, 12 Dec 2024 15:34:31 +0000 Subject: [PATCH 25/44] Clean-up and make style --- .../pipeline_stable_diffusion_3_ipa.py | 1403 ----------------- src/diffusers/models/attention_processor.py | 3 +- src/diffusers/models/embeddings.py | 6 +- 3 files changed, 5 insertions(+), 1407 deletions(-) delete mode 100644 examples/community/pipeline_stable_diffusion_3_ipa.py diff --git a/examples/community/pipeline_stable_diffusion_3_ipa.py b/examples/community/pipeline_stable_diffusion_3_ipa.py deleted file mode 100644 index d7830178e67f..000000000000 --- a/examples/community/pipeline_stable_diffusion_3_ipa.py +++ /dev/null @@ -1,1403 +0,0 @@ -# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import math -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from transformers import ( - CLIPTextModelWithProjection, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) - -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin -from diffusers.models.autoencoders import AutoencoderKL -from diffusers.models.embeddings import TimestepEmbedding, Timesteps -from diffusers.models.normalization import RMSNorm -from diffusers.models.transformers import SD3Transformer2DModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import ( - USE_PEFT_BACKEND, - is_torch_xla_available, - logging, - replace_example_docstring, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import randn_tensor - - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import StableDiffusion3Pipeline - - >>> pipe = StableDiffusion3Pipeline.from_pretrained( - ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 - ... ) - >>> pipe.to("cuda") - >>> prompt = "A cat holding a sign that says hello world" - >>> image = pipe(prompt).images[0] - >>> image.save("sd3.png") - ``` -""" - - -# FFN -def FeedForward(dim, mult=4): - inner_dim = int(dim * mult) - return nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, inner_dim, bias=False), - nn.GELU(), - nn.Linear(inner_dim, dim, bias=False), - ) - - -def reshape_tensor(x, heads): - bs, length, width = x.shape - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, length, heads, -1) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs, heads, length, -1) - return x - - -class PerceiverAttention(nn.Module): - def __init__(self, *, dim, dim_head=64, heads=8): - super().__init__() - self.scale = dim_head**-0.5 - self.dim_head = dim_head - self.heads = heads - inner_dim = dim_head * heads - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - def forward(self, x, latents, shift=None, scale=None): - """ - Args: - x (torch.Tensor): image features - shape (b, n1, D) - latent (torch.Tensor): latent features - shape (b, n2, D) - """ - x = self.norm1(x) - latents = self.norm2(latents) - - if shift is not None and scale is not None: - latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - b, l, _ = latents.shape - - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) - - # attention - scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v - - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) - - return self.to_out(out) - - -# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py -class TimeResampler(nn.Module): - def __init__( - self, - dim=1024, - depth=8, - dim_head=64, - heads=16, - num_queries=8, - embedding_dim=768, - output_dim=1024, - ff_mult=4, - timestep_in_dim=320, - timestep_flip_sin_to_cos=True, - timestep_freq_shift=0, - ): - super().__init__() - - self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) - - self.proj_in = nn.Linear(embedding_dim, dim) - - self.proj_out = nn.Linear(dim, output_dim) - self.norm_out = nn.LayerNorm(output_dim) - - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - # msa - PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), - # ff - FeedForward(dim=dim, mult=ff_mult), - # adaLN - nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)), - ] - ) - ) - - # time - self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) - self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu") - - # adaLN - # self.adaLN_modulation = nn.Sequential( - # nn.SiLU(), - # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True) - # ) - - def forward(self, x, timestep, need_temb=False): - timestep_emb = self.embedding_time(x, timestep) # bs, dim - - latents = self.latents.repeat(x.size(0), 1, 1) - - x = self.proj_in(x) - x = x + timestep_emb[:, None] - - for attn, ff, adaLN_modulation in self.layers: - shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1) - latents = attn(x, latents, shift_msa, scale_msa) + latents - - res = latents - for idx_ff in range(len(ff)): - layer_ff = ff[idx_ff] - latents = layer_ff(latents) - if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN - latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) - latents = latents + res - - # latents = ff(latents) + latents - - latents = self.proj_out(latents) - latents = self.norm_out(latents) - - if need_temb: - return latents, timestep_emb - else: - return latents - - def embedding_time(self, sample, timestep): - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, None) - return emb - - -class AdaLayerNorm(nn.Module): - """ - Norm layer adaptive layer norm zero (adaLN-Zero). - - Parameters: - embedding_dim (`int`): The size of each embedding vector. - num_embeddings (`int`): The size of the embeddings dictionary. - """ - - def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"): - super().__init__() - - self.silu = nn.SiLU() - - num_params = 2 if mode == "normal" else 6 - self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True) - self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) - self.mode = mode - - def forward( - self, - x, - hidden_dtype=None, - emb=None, - ): - emb = self.linear(self.silu(emb)) - if self.mode == "normal": - shift_msa, scale_msa = emb.chunk(2, dim=1) - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x - - elif self.mode == "zero": - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) - x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] - return x, gate_msa, shift_mlp, scale_mlp, gate_mlp - - -class JointIPAttnProcessor(torch.nn.Module): - """Attention processor used typically in processing the SD3-like self-attention projections.""" - - def __init__( - self, - hidden_size=None, - cross_attention_dim=None, - ip_hidden_states_dim=None, - ip_encoder_hidden_states_dim=None, - head_dim=None, - timesteps_emb_dim=1280, - ): - super().__init__() - - self.norm_ip = AdaLayerNorm(ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim) - self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) - self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) - self.norm_q = RMSNorm(head_dim, 1e-6) - self.norm_k = RMSNorm(head_dim, 1e-6) - self.norm_ip_k = RMSNorm(head_dim, 1e-6) - - def __call__( - self, - attn, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - emb_dict=None, - *args, - **kwargs, - ) -> torch.FloatTensor: - residual = hidden_states - - batch_size = hidden_states.shape[0] - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - img_query = query - img_key = key - img_value = value - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # `context` projections. - if encoder_hidden_states is not None: - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) - - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - # Split the attention outputs. - hidden_states, encoder_hidden_states = ( - hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], - ) - if not attn.context_pre_only: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - # IPadapter - ip_hidden_states = emb_dict.get("ip_hidden_states", None) - ip_hidden_states = self.get_ip_hidden_states( - attn, - img_query, - ip_hidden_states, - img_key, - img_value, - None, - None, - emb_dict["temb"], - ) - if ip_hidden_states is not None: - hidden_states = hidden_states + ip_hidden_states * emb_dict.get("scale", 1.0) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if encoder_hidden_states is not None: - return hidden_states, encoder_hidden_states - else: - return hidden_states - - def get_ip_hidden_states( - self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None - ): - if ip_hidden_states is None: - return None - - if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"): - return None - - # norm ip input - norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=temb) - - # to k and v - ip_key = self.to_k_ip(norm_ip_hidden_states) - ip_value = self.to_v_ip(norm_ip_hidden_states) - - # reshape - query = rearrange(query, "b l (h d) -> b h l d", h=attn.heads) - img_key = rearrange(img_key, "b l (h d) -> b h l d", h=attn.heads) - img_value = rearrange(img_value, "b l (h d) -> b h l d", h=attn.heads) - ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=attn.heads) - ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=attn.heads) - - # norm - query = self.norm_q(query) - img_key = self.norm_k(img_key) - ip_key = self.norm_ip_k(ip_key) - - # cat img - key = torch.cat([img_key, ip_key], dim=2) - value = torch.cat([img_value, ip_value], dim=2) - - # - ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)") - ip_hidden_states = ip_hidden_states.to(query.dtype) - return ip_hidden_states - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): - r""" - Args: - transformer ([`SD3Transformer2DModel`]): - Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. - scheduler ([`FlowMatchEulerDiscreteScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModelWithProjection`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, - with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` - as its dimension. - text_encoder_2 ([`CLIPTextModelWithProjection`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), - specifically the - [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) - variant. - text_encoder_3 ([`T5EncoderModel`]): - Frozen text-encoder. Stable Diffusion 3 uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the - [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_2 (`CLIPTokenizer`): - Second Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - tokenizer_3 (`T5TokenizerFast`): - Tokenizer of class - [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" - _optional_components = [] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] - - def __init__( - self, - transformer: SD3Transformer2DModel, - scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, - text_encoder: CLIPTextModelWithProjection, - tokenizer: CLIPTokenizer, - text_encoder_2: CLIPTextModelWithProjection, - tokenizer_2: CLIPTokenizer, - text_encoder_3: T5EncoderModel, - tokenizer_3: T5TokenizerFast, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - text_encoder_3=text_encoder_3, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - tokenizer_3=tokenizer_3, - transformer=transformer, - scheduler=scheduler, - ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.tokenizer_max_length = ( - self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 - ) - self.default_sample_size = ( - self.transformer.config.sample_size - if hasattr(self, "transformer") and self.transformer is not None - else 128 - ) - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_images_per_prompt: int = 1, - max_sequence_length: int = 256, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if self.text_encoder_3 is None: - return torch.zeros( - ( - batch_size * num_images_per_prompt, - self.tokenizer_max_length, - self.transformer.config.joint_attention_dim, - ), - device=device, - dtype=dtype, - ) - - text_inputs = self.tokenizer_3( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " - f" {max_sequence_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] - - dtype = self.text_encoder_3.dtype - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds - - def _get_clip_prompt_embeds( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - clip_skip: Optional[int] = None, - clip_model_index: int = 0, - ): - device = device or self._execution_device - - clip_tokenizers = [self.tokenizer, self.tokenizer_2] - clip_text_encoders = [self.text_encoder, self.text_encoder_2] - - tokenizer = clip_tokenizers[clip_model_index] - text_encoder = clip_text_encoders[clip_model_index] - - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer_max_length} tokens: {removed_text}" - ) - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - pooled_prompt_embeds = prompt_embeds[0] - - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) - pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) - - return prompt_embeds, pooled_prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - prompt_2: Union[str, List[str]], - prompt_3: Union[str, List[str]], - device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - negative_prompt_3: Optional[Union[str, List[str]]] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - clip_skip: Optional[int] = None, - max_sequence_length: int = 256, - lora_scale: Optional[float] = None, - ): - r""" - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - prompt_3 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - device = device or self._execution_device - - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if self.text_encoder is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder, lora_scale) - if self.text_encoder_2 is not None and USE_PEFT_BACKEND: - scale_lora_layers(self.text_encoder_2, lora_scale) - - prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - prompt_3 = prompt_3 or prompt - prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 - - prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=clip_skip, - clip_model_index=0, - ) - prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( - prompt=prompt_2, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=clip_skip, - clip_model_index=1, - ) - clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) - - t5_prompt_embed = self._get_t5_prompt_embeds( - prompt=prompt_3, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - clip_prompt_embeds = torch.nn.functional.pad( - clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) - ) - - prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) - pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) - - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - negative_prompt_3 = negative_prompt_3 or negative_prompt - - # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 - ) - negative_prompt_3 = ( - batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 - ) - - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( - negative_prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=None, - clip_model_index=0, - ) - negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( - negative_prompt_2, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=None, - clip_model_index=1, - ) - negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) - - t5_negative_prompt_embed = self._get_t5_prompt_embeds( - prompt=negative_prompt_3, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - negative_clip_prompt_embeds = torch.nn.functional.pad( - negative_clip_prompt_embeds, - (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), - ) - - negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) - negative_pooled_prompt_embeds = torch.cat( - [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 - ) - - if self.text_encoder is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - if self.text_encoder_2 is not None: - if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - def check_inputs( - self, - prompt, - prompt_2, - prompt_3, - height, - width, - negative_prompt=None, - negative_prompt_2=None, - negative_prompt_3=None, - prompt_embeds=None, - negative_prompt_embeds=None, - pooled_prompt_embeds=None, - negative_pooled_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - max_sequence_length=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_2 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt_3 is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): - raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") - elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): - raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_2 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - elif negative_prompt_3 is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if prompt_embeds is not None and pooled_prompt_embeds is None: - raise ValueError( - "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." - ) - - if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - latents=None, - ): - if latents is not None: - return latents.to(device=device, dtype=dtype) - - shape = ( - batch_size, - num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, - ) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - return latents - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def clip_skip(self): - return self._clip_skip - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 - - @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.inference_mode() - def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432): - from transformers import SiglipImageProcessor, SiglipVisionModel - - state_dict = torch.load(ip_adapter_path, map_location="cpu") - - device, dtype = self.transformer.device, self.transformer.dtype - image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path) - image_processor = SiglipImageProcessor.from_pretrained(image_encoder_path) - image_encoder.eval() - image_encoder.to(device, dtype=dtype) - self.image_encoder = image_encoder - self.clip_image_processor = image_processor - - sample_class = TimeResampler - image_proj_model = sample_class( - dim=1280, - depth=4, - dim_head=64, - heads=20, - num_queries=nb_token, - embedding_dim=1152, - output_dim=output_dim, - ff_mult=4, - timestep_in_dim=320, - timestep_flip_sin_to_cos=True, - timestep_freq_shift=0, - ) - image_proj_model.eval() - image_proj_model.to(device, dtype=dtype) - key_name = image_proj_model.load_state_dict(state_dict["image_proj"], strict=False) - print(f"=> loading image_proj_model: {key_name}") - - self.image_proj_model = image_proj_model - - attn_procs = {} - transformer = self.transformer - for idx_name, name in enumerate(transformer.attn_processors.keys()): - hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads - ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads - ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim - - attn_procs[name] = JointIPAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=transformer.config.caption_projection_dim, - ip_hidden_states_dim=ip_hidden_states_dim, - ip_encoder_hidden_states_dim=ip_encoder_hidden_states_dim, - head_dim=transformer.config.attention_head_dim, - timesteps_emb_dim=1280, - ).to(device, dtype=dtype) - - self.transformer.set_attn_processor(attn_procs) - tmp_ip_layers = torch.nn.ModuleList(self.transformer.attn_processors.values()) - - key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False) - print(f"=> loading ip_adapter: {key_name}") - - @torch.inference_mode() - def encode_clip_image_emb(self, clip_image, device, dtype): - # clip - clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values - clip_image_tensor = clip_image_tensor.to(device, dtype=dtype) - clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2] - clip_image_embeds = torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0) - - return clip_image_embeds - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - prompt_3: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 7.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - negative_prompt_2: Optional[Union[str, List[str]]] = None, - negative_prompt_3: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 256, - # ipa - clip_image=None, - ipadapter_scale=1.0, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead - prompt_3 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is - will be used instead - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. This is set to 1024 by default for the best results. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. This is set to 1024 by default for the best results. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used instead - negative_prompt_3 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used instead - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. - joint_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, *optional*): - A function that calls at the end of each denoising steps during the inference. The function is called - with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, - callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by - `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. - - Examples: - - Returns: - [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a - `tuple`. When returning a tuple, the first element is a list with the generated images. - """ - - height = height or self.default_sample_size * self.vae_scale_factor - width = width or self.default_sample_size * self.vae_scale_factor - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - prompt_2, - prompt_3, - height, - width, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, - max_sequence_length=max_sequence_length, - ) - - self._guidance_scale = guidance_scale - self._clip_skip = clip_skip - self._joint_attention_kwargs = joint_attention_kwargs - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - dtype = self.transformer.dtype - - lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None - ) - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt( - prompt=prompt, - prompt_2=prompt_2, - prompt_3=prompt_3, - negative_prompt=negative_prompt, - negative_prompt_2=negative_prompt_2, - negative_prompt_3=negative_prompt_3, - do_classifier_free_guidance=self.do_classifier_free_guidance, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, - device=device, - clip_skip=self.clip_skip, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - - # 3. prepare clip emb - clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size))) - clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype) - - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latent_model_input.shape[0]) - - image_prompt_embeds, timestep_emb = self.image_proj_model( - clip_image_embeds, timestep.to(dtype=latents.dtype), need_temb=True - ) - - joint_attention_kwargs = { - "emb_dict": { - "ip_hidden_states": image_prompt_embeds, - "temb": timestep_emb, - "scale": ipadapter_scale, - } - } - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - joint_attention_kwargs=joint_attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if output_type == "latent": - image = latents - - else: - latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor - - image = self.vae.decode(latents, return_dict=False)[0] - image = self.image_processor.postprocess(image, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image,) - - return StableDiffusion3PipelineOutput(images=image) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c76909d8255f..a8fb8bf28f89 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -3927,8 +3927,9 @@ def __call__( key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 2e0c773d8e04..a65e1338b4f5 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -21,7 +21,7 @@ from ..utils import deprecate from .activations import FP32SiLU, get_activation -from .attention_processor import Attention, FusedAttnProcessor2_0 +from .attention_processor import Attention def get_timestep_embedding( @@ -2143,7 +2143,7 @@ def forward(self, x, latents, timestep_emb): emb = self.adaln_proj(self.adaln_silu(timestep_emb)) shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1) - # Fused Attention + # Fused Attention residual = latents x = self.ln0(x) latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None] @@ -2171,7 +2171,7 @@ def forward(self, x, latents, timestep_emb): latents = self.attn.to_out[0](latents) latents = self.attn.to_out[1](latents) latents = latents + residual - + ## FeedForward residual = latents latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] From 84aa4a3ef70cebcc7e6a433fb4de3b97ac7fd2a3 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Fri, 13 Dec 2024 03:32:53 +0000 Subject: [PATCH 26/44] Minor changes in code structure --- src/diffusers/models/attention_processor.py | 19 ++++++++----------- src/diffusers/models/embeddings.py | 4 +--- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a8fb8bf28f89..e004893b20b6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5090,9 +5090,6 @@ def __call__( query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - img_query = query - img_key = key - img_value = value inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -5100,6 +5097,9 @@ def __call__( query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + img_query = query + img_key = key + img_value = value if attn.norm_q is not None: query = attn.norm_q(query) @@ -5154,26 +5154,23 @@ def __call__( ip_value = self.to_v_ip(norm_ip_hidden_states) # Reshape - img_query = img_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - img_key = img_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - img_value = img_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # Norm - img_query = self.norm_q(img_query) + query = self.norm_q(img_query) img_key = self.norm_k(img_key) ip_key = self.norm_ip_k(ip_key) # cat img - img_key = torch.cat([img_key, ip_key], dim=2) - img_value = torch.cat([img_value, ip_value], dim=2) + key = torch.cat([img_key, ip_key], dim=2) + value = torch.cat([img_value, ip_value], dim=2) ip_hidden_states = F.scaled_dot_product_attention( - img_query, img_key, img_value, dropout_p=0.0, is_causal=False + query, key, value, dropout_p=0.0, is_causal=False ) ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(img_query.dtype) + ip_hidden_states = ip_hidden_states.to(query.dtype) hidden_states = hidden_states + ip_hidden_states * self.scale diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a65e1338b4f5..b6b8ac810d6a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2152,9 +2152,7 @@ def forward(self, x, latents, timestep_emb): query = self.attn.to_q(latents) kv_input = torch.cat((x, latents), dim=-2) - kv = self.attn.to_kv(kv_input) - split_size = kv.shape[-1] // 2 - key, value = torch.split(kv, split_size, dim=-1) + key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1) inner_dim = key.shape[-1] head_dim = inner_dim // self.attn.heads From 34793fbd27a291b0536de93127dbf8a9dc20b470 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Fri, 13 Dec 2024 04:32:24 +0000 Subject: [PATCH 27/44] make style && make quality --- src/diffusers/models/attention_processor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e004893b20b6..260e9aa51d46 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5166,9 +5166,7 @@ def __call__( key = torch.cat([img_key, ip_key], dim=2) value = torch.cat([img_value, ip_value], dim=2) - ip_hidden_states = F.scaled_dot_product_attention( - query, key, value, dropout_p=0.0, is_causal=False - ) + ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.to(query.dtype) From 68169f8f2eade82072524251157668acfffb45a3 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Mon, 16 Dec 2024 18:56:31 +0000 Subject: [PATCH 28/44] Updated dosctrings and doc entries --- docs/source/en/api/attnprocessor.md | 3 + docs/source/en/api/loaders/ip_adapter.md | 6 ++ src/diffusers/loaders/ip_adapter.py | 41 ++++++----- src/diffusers/models/attention_processor.py | 39 ++++++++++- src/diffusers/models/embeddings.py | 68 ++++++++++++++++++- .../models/transformers/transformer_sd3.py | 14 +++- .../pipeline_stable_diffusion_3.py | 63 ++++++++++++----- 7 files changed, 196 insertions(+), 38 deletions(-) diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index 5b1f0be72ae6..db6a761d5607 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -52,3 +52,6 @@ An attention processor is a class for applying different types of attention mech ## AttnProcessorNPU [[autodoc]] models.attention_processor.AttnProcessorNPU + +## IPAdapterJointAttnProcessor2_0 +[[autodoc]] models.attention_processor.IPAdapterJointAttnProcessor2_0 \ No newline at end of file diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md index a10f30ef8e5b..946a8b1af875 100644 --- a/docs/source/en/api/loaders/ip_adapter.md +++ b/docs/source/en/api/loaders/ip_adapter.md @@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading] [[autodoc]] loaders.ip_adapter.IPAdapterMixin +## SD3IPAdapterMixin + +[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin + - all + - is_ip_adapter_active + ## IPAdapterMaskProcessor [[autodoc]] image_processor.IPAdapterMaskProcessor \ No newline at end of file diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index d2aca385c49c..faa0528d7c4b 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -358,13 +358,13 @@ class SD3IPAdapterMixin: @property def is_ip_adapter_active(self) -> bool: - r"""Checks if any ip_adapter attention processor have scale > 0. + """Checks if IP-Adapter is loaded and scale > 0. IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0, - image is irrelevant. + the image context is irrelevant. Returns: - `bool`: True when ip_adapter is loaded and any ip_adapter layer scale > 0. + `bool`: True when IP-Adapter is loaded and any layer has scale > 0. """ scales = [ attn_proc.scale @@ -382,7 +382,7 @@ def load_ip_adapter( weight_name: str, image_encoder_folder: Optional[str] = "image_encoder", **kwargs, - ): + ) -> None: """ Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -500,19 +500,19 @@ def load_ip_adapter( image_encoder_subfolder = Path(image_encoder_folder).as_posix() # Commons args for loading image encoder and image processor - args = dict( - pretrained_model_name_or_path_or_dict, - subfolder=image_encoder_subfolder, - low_cpu_mem_usage=low_cpu_mem_usage, - cache_dir=cache_dir, - local_files_only=local_files_only, - ) + kwargs = { + "low_cpu_mem_usage": low_cpu_mem_usage, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } self.register_modules( - feature_extractor=SiglipImageProcessor.from_pretrained(**args).to( + feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), + image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to( self.device, dtype=self.dtype ), - image_encoder=SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype), ) else: raise ValueError( @@ -527,11 +527,11 @@ def load_ip_adapter( # Load IP-Adapter into transformer self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage) - def set_ip_adapter_scale(self, scale: float): + def set_ip_adapter_scale(self, scale: float) -> None: """ - Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image - prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages the model to produce more - diverse images, but they may not be as aligned with the image prompt. + Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only + conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages + the model to produce more diverse images, but they may not be as aligned with the image prompt. Example: @@ -540,12 +540,17 @@ def set_ip_adapter_scale(self, scale: float): >>> pipeline.set_ip_adapter_scale(0.6) >>> ... ``` + + Args: + scale (float): + IP-Adapter scale to be set. + """ for attn_processor in self.transformer.attn_processors.values(): if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0): attn_processor.scale = scale - def unload_ip_adapter(self): + def unload_ip_adapter(self) -> None: """ Unloads the IP Adapter weights. diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b1c6dad86acc..d1788af3df35 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5149,7 +5149,22 @@ def __call__( class IPAdapterJointAttnProcessor2_0(torch.nn.Module): - """Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections.""" + """ + Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with + additional image-based information and timestep embeddings. + + Args: + hidden_size (`int`): + The number of hidden channels. + ip_hidden_states_dim (`int`): + The image feature dimension. + head_dim (`int`): + The number of head channels. + timesteps_emb_dim (`int`, defaults to 1280): + The number of input channels for timestep embedding. + scale (`float`, defaults to 0.5): + IP-Adapter scale. + """ def __init__( self, @@ -5181,6 +5196,28 @@ def __call__( ip_hidden_states: torch.FloatTensor = None, temb: torch.FloatTensor = None, ) -> torch.FloatTensor: + """ + Perform the attention computation, integrating image features (if provided) and timestep embeddings. + + If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0. + + Args: + attn (`Attention`): + Attention instance. + hidden_states (`torch.FloatTensor`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + The encoder hidden states. + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask. + ip_hidden_states (`torch.FloatTensor`, *optional*): + Image embeddings. + temb (`torch.FloatTensor`, *optional*): + Timestep embeddings. + + Returns: + `torch.FloatTensor`: Output hidden states. + """ residual = hidden_states batch_size = hidden_states.shape[0] diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index ae0c92479180..bbde3774a181 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2119,6 +2119,19 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: class IPAdapterTimeImageProjectionBlock(nn.Module): + """Block for IPAdapterTimeImageProjection. + + Args: + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + """ + def __init__( self, hidden_dim: int = 1280, @@ -2152,7 +2165,21 @@ def __init__( self.attn.to_k = None self.attn.to_v = None - def forward(self, x, latents, timestep_emb): + def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + latents (`torch.Tensor`): + Latent features. + timestep_emb (`torch.Tensor`): + Timestep embedding. + + Returns: + `torch.Tensor`: Output latent features. + """ + # Shift and scale for AdaLayerNorm emb = self.adaln_proj(self.adaln_silu(timestep_emb)) shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1) @@ -2192,6 +2219,33 @@ def forward(self, x, latents, timestep_emb): # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py class IPAdapterTimeImageProjection(nn.Module): + """Resampler of SD3 IP-Adapter with timestep embedding. + + Args: + embed_dim (`int`, defaults to 1152): + The feature dimension. + output_dim (`int`, defaults to 2432): + The number of output channels. + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + depth (`int`, defaults to 4): + The number of blocks. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + num_queries (`int`, defaults to 64): + The number of queries. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + timestep_in_dim (`int`, defaults to 320): + The number of input channels for timestep embedding. + timestep_flip_sin_to_cos (`bool`, defaults to True): + Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False). + timestep_freq_shift (`int`, defaults to 0): + Controls the timestep delta between frequencies between dimensions. + """ + def __init__( self, embed_dim: int = 1152, @@ -2217,7 +2271,17 @@ def __init__( self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu") - def forward(self, x, timestep): + def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + timestep (`torch.Tensor`): + Timestep in denoising process. + Returns: + `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb). + """ timestep_emb = self.time_proj(timestep).to(dtype=x.dtype) timestep_emb = self.time_embedding(timestep_emb) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index fb46f2c03f6f..fe0d37ea0c06 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -331,7 +331,19 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool): + def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool) -> None: + """Sets IP-Adapter attention processors, image projection, and loads state_dict. + + Args: + state_dict (`Dict`): + PyTorch state dict with keys "ip_adapter", which contains parameters for attention processors, and + "image_proj", which contains parameters for image projection net. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ # IP-Adapter cross attention parameters hidden_size = self.config.attention_head_dim * self.config.num_attention_heads ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 9efc433d75a5..265104bd8207 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -680,7 +680,16 @@ def interrupt(self): return self._interrupt # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image - def encode_image(self, image): + def encode_image(self, image: PipelineImageInput) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values @@ -690,17 +699,42 @@ def encode_image(self, image): # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance - ): - if ip_adapter_image_embeds is None: + self, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: single_image_embeds = self.encode_image(ip_adapter_image) if do_classifier_free_guidance: single_negative_image_embeds = torch.zeros_like(single_image_embeds) else: - if do_classifier_free_guidance: - single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - else: - single_image_embeds = ip_adapter_image_embeds + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) @@ -733,7 +767,7 @@ def __call__( pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, - ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -810,11 +844,10 @@ def __call__( weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. - ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): - Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of - IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should - contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not - provided, embeddings are computed from the `ip_adapter_image` input argument. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -950,8 +983,6 @@ def __call__( ) # 6. Prepare image embeddings - # Either image is passed and ip_adapter is active - # Or image_embeds are passed directly if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, From 24e6880bec74ceb15f8e9af79175f064be9eb14e Mon Sep 17 00:00:00 2001 From: hlky <hlky@hlky.ac> Date: Mon, 16 Dec 2024 22:15:21 +0000 Subject: [PATCH 29/44] make --- .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index f0d8488570d4..7906f699ea2c 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -794,7 +794,6 @@ def __call__( skip_layer_guidance_stop: float = 0.2, skip_layer_guidance_start: float = 0.01, mu: Optional[float] = None, - ): r""" Function invoked when calling the pipeline for generation. From 43d2e77de14861bc02f76e2d65b132a41f94f38f Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Tue, 17 Dec 2024 00:43:31 +0000 Subject: [PATCH 30/44] More docs and small refactors --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/attnprocessor.md | 4 +- .../source/en/api/loaders/transformers_sd3.md | 29 ++++++ src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/ip_adapter.py | 14 +-- src/diffusers/loaders/transformers_sd3.py | 89 ++++++++++++++++++ src/diffusers/models/attention_processor.py | 4 +- .../models/transformers/transformer_sd3.py | 93 ++----------------- 8 files changed, 140 insertions(+), 97 deletions(-) create mode 100644 docs/source/en/api/loaders/transformers_sd3.md create mode 100644 src/diffusers/loaders/transformers_sd3.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d1404a1d6ea6..1fd3ccb06efe 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -234,6 +234,8 @@ title: Textual Inversion - local: api/loaders/unet title: UNet + - local: api/loaders/transformers_sd3 + title: SD3Transformer2D - local: api/loaders/peft title: PEFT title: Loaders diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index db6a761d5607..e978e1eaadac 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -53,5 +53,5 @@ An attention processor is a class for applying different types of attention mech ## AttnProcessorNPU [[autodoc]] models.attention_processor.AttnProcessorNPU -## IPAdapterJointAttnProcessor2_0 -[[autodoc]] models.attention_processor.IPAdapterJointAttnProcessor2_0 \ No newline at end of file +## SD3IPAdapterJointAttnProcessor2_0 +[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0 \ No newline at end of file diff --git a/docs/source/en/api/loaders/transformers_sd3.md b/docs/source/en/api/loaders/transformers_sd3.md new file mode 100644 index 000000000000..fd202dd9b17d --- /dev/null +++ b/docs/source/en/api/loaders/transformers_sd3.md @@ -0,0 +1,29 @@ +<!--Copyright 2024 The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +--> + +# SD3Transformer2D + +This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. + +The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs. + +<Tip> + +To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide. + +</Tip> + +## SD3Transformer2DLoadersMixin + +[[autodoc]] loaders.transformers_sd3.SD3Transformer2DLoadersMixin + - all + - _load_ip_adapter_weights \ No newline at end of file diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index e4ffdd2324af..bda9820d04c5 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -56,6 +56,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["transformers_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): @@ -82,6 +83,7 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .transformers_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index faa0528d7c4b..ddd943971220 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -40,9 +40,9 @@ AttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, - IPAdapterJointAttnProcessor2_0, IPAdapterXFormersAttnProcessor, JointAttnProcessor2_0, + SD3IPAdapterJointAttnProcessor2_0, ) @@ -369,7 +369,7 @@ def is_ip_adapter_active(self) -> bool: scales = [ attn_proc.scale for attn_proc in self.transformer.attn_processors.values() - if isinstance(attn_proc, IPAdapterJointAttnProcessor2_0) + if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0) ] return len(scales) > 0 and any(scale > 0 for scale in scales) @@ -379,7 +379,7 @@ def load_ip_adapter( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], subfolder: str, - weight_name: str, + weight_name: str = "ip-adapter.safetensors", image_encoder_folder: Optional[str] = "image_encoder", **kwargs, ) -> None: @@ -396,7 +396,7 @@ def load_ip_adapter( subfolder (`str`): The subfolder location of a model file within a larger model repository on the Hub or locally. If a list is passed, it should have the same length as `weight_name`. - weight_name (`str`): + weight_name (`str`, defaults to "ip-adapter.safetensors"): The name of the weight file to load. If a list is passed, it should have the same length as `subfolder`. image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): @@ -547,7 +547,7 @@ def set_ip_adapter_scale(self, scale: float) -> None: """ for attn_processor in self.transformer.attn_processors.values(): - if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0): + if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0): attn_processor.scale = scale def unload_ip_adapter(self) -> None: @@ -577,7 +577,9 @@ def unload_ip_adapter(self) -> None: # Restore original attention processors layers attn_procs = { - name: (JointAttnProcessor2_0() if isinstance(value, IPAdapterJointAttnProcessor2_0) else value.__class__()) + name: ( + JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__() + ) for name, value in self.transformer.attn_processors.items() } self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/transformers_sd3.py b/src/diffusers/loaders/transformers_sd3.py new file mode 100644 index 000000000000..e93659318fd4 --- /dev/null +++ b/src/diffusers/loaders/transformers_sd3.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 +from ..models.embeddings import IPAdapterTimeImageProjection +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta + + +class SD3Transformer2DLoadersMixin: + """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`.""" + + def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None: + """Sets IP-Adapter attention processors, image projection, and loads state_dict. + + Args: + state_dict (`Dict`): + State dict with keys "ip_adapter", which contains parameters for attention processors, and + "image_proj", which contains parameters for image projection net. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # IP-Adapter cross attention parameters + hidden_size = self.config.attention_head_dim * self.config.num_attention_heads + ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads + timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] + + # Dict where key is transformer layer index, value is attention processor's state dict + # ip_adapter state dict keys example: "0.norm_ip.linear.weight" + layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} + for key, weights in state_dict["ip_adapter"].items(): + idx, name = key.split(".", maxsplit=1) + layer_state_dict[int(idx)][name] = weights + + # Create IP-Adapter attention processor + attn_procs = {} + for idx, name in enumerate(self.attn_processors.keys()): + attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, + ip_hidden_states_dim=ip_hidden_states_dim, + head_dim=self.config.attention_head_dim, + timesteps_emb_dim=timesteps_emb_dim, + ).to(self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) + else: + load_model_dict_into_meta( + attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype + ) + + self.set_attn_processor(attn_procs) + + # Image projetion parameters + embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0] + heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64 + num_queries = state_dict["image_proj"]["latents"].shape[1] + timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] + + # Image projection + self.image_proj = IPAdapterTimeImageProjection( + embed_dim=embed_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + heads=heads, + num_queries=num_queries, + timestep_in_dim=timestep_in_dim, + ).to(device=self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + self.image_proj.load_state_dict(state_dict, strict=True) + else: + load_model_dict_into_meta(self.image_proj, state_dict, device=self.device, dtype=self.dtype) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 500da45a7802..2e17c18540ec 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5160,7 +5160,7 @@ def __call__( return hidden_states -class IPAdapterJointAttnProcessor2_0(torch.nn.Module): +class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module): """ Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with additional image-based information and timestep embeddings. @@ -5844,7 +5844,7 @@ def __call__( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, - IPAdapterJointAttnProcessor2_0, + SD3IPAdapterJointAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, LoRAAttnProcessor, diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index fe0d37ea0c06..6c7d1c2868d5 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -19,19 +19,19 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders.transformers_sd3 import SD3Transformer2DLoadersMixin from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( Attention, AttentionProcessor, FusedJointAttnProcessor2_0, - IPAdapterJointAttnProcessor2_0, JointAttnProcessor2_0, ) -from ...models.modeling_utils import ModelMixin, load_model_dict_into_meta +from ...models.modeling_utils import ModelMixin from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph -from ..embeddings import CombinedTimestepTextProjEmbeddings, IPAdapterTimeImageProjection, PatchEmbed +from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed from ..modeling_outputs import Transformer2DModelOutput @@ -104,7 +104,9 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor): return hidden_states -class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SD3Transformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin +): """ The Transformer model introduced in Stable Diffusion 3. @@ -331,89 +333,6 @@ def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool) -> None: - """Sets IP-Adapter attention processors, image projection, and loads state_dict. - - Args: - state_dict (`Dict`): - PyTorch state dict with keys "ip_adapter", which contains parameters for attention processors, and - "image_proj", which contains parameters for image projection net. - low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading only loading the pretrained weights and not initializing the weights. This also - tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. - Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this - argument to `True` will raise an error. - """ - # IP-Adapter cross attention parameters - hidden_size = self.config.attention_head_dim * self.config.num_attention_heads - ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads - timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] - - # Dict where key is transformer layer index, value is attention processor's state dict - # ip_adapter state dict keys example: "0.norm_ip.linear.weight" - layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} - for key, weights in state_dict["ip_adapter"].items(): - idx, name = key.split(".", maxsplit=1) - layer_state_dict[int(idx)][name] = weights - - # Create IP-Adapter attention processor - attn_procs = {} - for idx, name in enumerate(self.attn_processors.keys()): - attn_procs[name] = IPAdapterJointAttnProcessor2_0( - hidden_size=hidden_size, - ip_hidden_states_dim=ip_hidden_states_dim, - head_dim=self.config.attention_head_dim, - timesteps_emb_dim=timesteps_emb_dim, - ).to(self.device, dtype=self.dtype) - - if not low_cpu_mem_usage: - attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) - else: - load_model_dict_into_meta( - attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype - ) - - self.set_attn_processor(attn_procs) - - # Convert image_proj state dict to diffusers - image_proj_state_dict = {} - for key, value in state_dict["image_proj"].items(): - if key.startswith("layers."): - idx = key.split(".")[1] - key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0") - key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1") - key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q") - key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv") - key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0") - key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm") - key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj") - key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2") - key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj") - image_proj_state_dict[key] = value - - # Image projetion parameters - embed_dim = image_proj_state_dict["proj_in.weight"].shape[1] - output_dim = image_proj_state_dict["proj_out.weight"].shape[0] - hidden_dim = image_proj_state_dict["proj_in.weight"].shape[0] - heads = image_proj_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64 - num_queries = image_proj_state_dict["latents"].shape[1] - timestep_in_dim = image_proj_state_dict["time_embedding.linear_1.weight"].shape[1] - - # Image projection - self.image_proj = IPAdapterTimeImageProjection( - embed_dim=embed_dim, - output_dim=output_dim, - hidden_dim=hidden_dim, - heads=heads, - num_queries=num_queries, - timestep_in_dim=timestep_in_dim, - ).to(device=self.device, dtype=self.dtype) - - if not low_cpu_mem_usage: - self.image_proj.load_state_dict(image_proj_state_dict, strict=True) - else: - load_model_dict_into_meta(self.image_proj, image_proj_state_dict, device=self.device, dtype=self.dtype) - def forward( self, hidden_states: torch.FloatTensor, From 44e3847715a155203ff0ae68eb6de934e1d781ff Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 18 Dec 2024 00:47:37 +0000 Subject: [PATCH 31/44] Fix in loading state dict --- src/diffusers/loaders/transformers_sd3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/transformers_sd3.py b/src/diffusers/loaders/transformers_sd3.py index e93659318fd4..435d1da06ca1 100644 --- a/src/diffusers/loaders/transformers_sd3.py +++ b/src/diffusers/loaders/transformers_sd3.py @@ -84,6 +84,6 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _ ).to(device=self.device, dtype=self.dtype) if not low_cpu_mem_usage: - self.image_proj.load_state_dict(state_dict, strict=True) + self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) else: - load_model_dict_into_meta(self.image_proj, state_dict, device=self.device, dtype=self.dtype) + load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) From 178e513afcbef913b1121b98c502857f61caf7a0 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 18 Dec 2024 02:20:54 +0000 Subject: [PATCH 32/44] Enabled cpu offload --- src/diffusers/loaders/ip_adapter.py | 8 ++++---- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index ddd943971220..11ce4f1634d7 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -378,8 +378,8 @@ def is_ip_adapter_active(self) -> bool: def load_ip_adapter( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - subfolder: str, weight_name: str = "ip-adapter.safetensors", + subfolder: Optional[str] = None, image_encoder_folder: Optional[str] = "image_encoder", **kwargs, ) -> None: @@ -393,12 +393,12 @@ def load_ip_adapter( with [`ModelMixin.save_pretrained`]. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - subfolder (`str`): - The subfolder location of a model file within a larger model repository on the Hub or locally. If a - list is passed, it should have the same length as `weight_name`. weight_name (`str`, defaults to "ip-adapter.safetensors"): The name of the weight file to load. If a list is passed, it should have the same length as `subfolder`. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): The subfolder location of the image encoder within a larger model repository on the Hub or locally. Pass `None` to not load the image encoder. If the image encoder is located in a folder inside diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 7906f699ea2c..8c9a88c09946 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -183,6 +183,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle """ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _exclude_from_cpu_offload = ["image_encoder"] _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] @@ -694,12 +695,14 @@ def interrupt(self): return self._interrupt # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image - def encode_image(self, image: PipelineImageInput) -> torch.Tensor: + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: """Encodes the given image into a feature representation using a pre-trained image encoder. Args: image (`PipelineImageInput`): Input image to be encoded. + device: (`torch.device`): + Torch device. Returns: `torch.Tensor`: The encoded image feature representation. @@ -707,7 +710,7 @@ def encode_image(self, image: PipelineImageInput) -> torch.Tensor: if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values - image = image.to(device=self.device, dtype=self.dtype) + image = image.to(device=device, dtype=self.dtype) return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] @@ -744,7 +747,7 @@ def prepare_ip_adapter_image_embeds( else: single_image_embeds = ip_adapter_image_embeds elif ip_adapter_image is not None: - single_image_embeds = self.encode_image(ip_adapter_image) + single_image_embeds = self.encode_image(ip_adapter_image, device) if do_classifier_free_guidance: single_negative_image_embeds = torch.zeros_like(single_image_embeds) else: From 8daca65764c61878ca2411eac3da8d08d349d8dc Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 18 Dec 2024 11:05:57 +0000 Subject: [PATCH 33/44] Renaming from transformers_sd3 to transformer_sd3 --- docs/source/en/_toctree.yml | 2 +- .../api/loaders/{transformers_sd3.md => transformer_sd3.md} | 0 src/diffusers/loaders/__init__.py | 4 ++-- .../loaders/{transformers_sd3.py => transformer_sd3.py} | 0 src/diffusers/models/transformers/transformer_sd3.py | 3 +-- 5 files changed, 4 insertions(+), 5 deletions(-) rename docs/source/en/api/loaders/{transformers_sd3.md => transformer_sd3.md} (100%) rename src/diffusers/loaders/{transformers_sd3.py => transformer_sd3.py} (100%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0e576a0df2bf..6ac66db73026 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -238,7 +238,7 @@ title: Textual Inversion - local: api/loaders/unet title: UNet - - local: api/loaders/transformers_sd3 + - local: api/loaders/transformer_sd3 title: SD3Transformer2D - local: api/loaders/peft title: PEFT diff --git a/docs/source/en/api/loaders/transformers_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md similarity index 100% rename from docs/source/en/api/loaders/transformers_sd3.md rename to docs/source/en/api/loaders/transformer_sd3.md diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 3aaefb39cca5..dce271967ab3 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -56,7 +56,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] - _import_structure["transformers_sd3"] = ["SD3Transformer2DLoadersMixin"] + _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): @@ -85,7 +85,7 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin - from .transformers_sd3 import SD3Transformer2DLoadersMixin + from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers diff --git a/src/diffusers/loaders/transformers_sd3.py b/src/diffusers/loaders/transformer_sd3.py similarity index 100% rename from src/diffusers/loaders/transformers_sd3.py rename to src/diffusers/loaders/transformer_sd3.py diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 6c7d1c2868d5..235f9c3daf99 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -18,8 +18,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...loaders.transformers_sd3 import SD3Transformer2DLoadersMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( Attention, From 7c918db6be3d7e088415bc5530b1ee63d326d282 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 18 Dec 2024 11:37:52 +0000 Subject: [PATCH 34/44] Missing rename --- docs/source/en/api/loaders/transformer_sd3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md index fd202dd9b17d..4fc9603054b4 100644 --- a/docs/source/en/api/loaders/transformer_sd3.md +++ b/docs/source/en/api/loaders/transformer_sd3.md @@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## SD3Transformer2DLoadersMixin -[[autodoc]] loaders.transformers_sd3.SD3Transformer2DLoadersMixin +[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin - all - _load_ip_adapter_weights \ No newline at end of file From 99a6d594d20f16bc698a0d501b3673f3f2efce07 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 18 Dec 2024 16:08:11 +0000 Subject: [PATCH 35/44] Updated docs for SD3 pipeline --- .../stable_diffusion/stable_diffusion_3.md | 69 ++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 8170c5280d38..a1659dd41690 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -59,9 +59,76 @@ image.save("sd3_hello_world.png") - [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large) - [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo) +## Image Prompting with IP-Adapters + +An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need: + +- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder. +- `feature_extractor`: Image processor that prepares the input image for the choosen `image_encoder`. +- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection. + +IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the `set_ip_adapter_scale()` function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally. + +```python +import torch +from PIL import Image + +from diffusers import StableDiffusion3Pipeline +from transformers import SiglipVisionModel, SiglipImageProcessor + +image_encoder_id = "google/siglip-so400m-patch14-384" +ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter" + +feature_extractor = SiglipImageProcessor.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +) +image_encoder = SiglipVisionModel.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +).to( "cuda") + +pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-large", + torch_dtype=torch.float16, + feature_extractor=feature_extractor, + image_encoder=image_encoder, +).to("cuda") + +pipe.load_ip_adapter(ip_adapter_id) +pipe.set_ip_adapter_scale(0.6) + +ref_img = Image.open("image.jpg").convert('RGB') + +image = pipe( + width=1024, + height=1024, + prompt="a cat", + negative_prompt="lowres, low quality, worst quality", + num_inference_steps=24, + guidance_scale=5.0, + ip_adapter_image=ref_img +).images[0] + +image.save("result.jpg") +``` + +<div class="justify-center"> + <img src="https://github.com/user-attachments/assets/bc93dd29-ff50-48de-971f-c306653a3a10"/> + <figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption> +</div> + + +<Tip> + +Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work. + +</Tip> + + ## Memory Optimisations for SD3 -SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. +SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. ### Running Inference with Model Offloading From 02a6d90eddc9e5e52377474cd39537f77c601d9c Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:15:20 +0000 Subject: [PATCH 36/44] Update docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../en/api/pipelines/stable_diffusion/stable_diffusion_3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index a1659dd41690..30b3a550e223 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -64,7 +64,7 @@ image.save("sd3_hello_world.png") An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need: - `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder. -- `feature_extractor`: Image processor that prepares the input image for the choosen `image_encoder`. +- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`. - `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection. IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the `set_ip_adapter_scale()` function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally. From 64ab7f9fe8e2f44eb64afc7f4c79bfedf61ccaa9 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 18 Dec 2024 17:22:12 +0000 Subject: [PATCH 37/44] Minor doc correction --- .../en/api/pipelines/stable_diffusion/stable_diffusion_3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 30b3a550e223..7cbeac1f5827 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -67,7 +67,7 @@ An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. T - `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`. - `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection. -IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the `set_ip_adapter_scale()` function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally. +IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally. ```python import torch From b254aa3a247f8af13b273c6af0e69b4e9be64d6b Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Wed, 18 Dec 2024 18:46:31 +0000 Subject: [PATCH 38/44] Updated img source to hf/documentation-images --- .../en/api/pipelines/stable_diffusion/stable_diffusion_3.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 7cbeac1f5827..eb67964ab0bd 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -114,7 +114,7 @@ image.save("result.jpg") ``` <div class="justify-center"> - <img src="https://github.com/user-attachments/assets/bc93dd29-ff50-48de-971f-c306653a3a10"/> + <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png"/> <figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption> </div> From 5c28161bf4891387e0562acc00e190a7d5a2a4eb Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Thu, 19 Dec 2024 03:22:13 +0000 Subject: [PATCH 39/44] image_proj is now called from SD3Transformer2DModel --- .../models/transformers/transformer_sd3.py | 16 ++++++++++++++-- .../pipeline_stable_diffusion_3.py | 15 ++------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 235f9c3daf99..0898b2a7c481 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -338,6 +338,7 @@ def forward( encoder_hidden_states: torch.FloatTensor = None, pooled_projections: torch.FloatTensor = None, timestep: torch.LongTensor = None, + ip_adapter_image_embeds: Optional[torch.FloatTensor] = None, block_controlnet_hidden_states: List = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, @@ -351,10 +352,12 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of input conditions. timestep (`torch.LongTensor`): Used to indicate denoising step. + ip_adapter_image_embeds (`torch.FloatTensor`): + Image embeddings for IP-Adapter. block_controlnet_hidden_states (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): @@ -392,6 +395,15 @@ def forward( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if ip_adapter_image_embeds is not None: + ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep) + ip_embeds = {"ip_hidden_states": ip_hidden_states, "temb": ip_temb} + + if joint_attention_kwargs is None: + joint_attention_kwargs = ip_embeds + else: + joint_attention_kwargs.update(**ip_embeds) + for index_block, block in enumerate(self.transformer_blocks): # Skip specified layers is_skip = True if skip_layers is not None and index_block in skip_layers else False diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 8c9a88c09946..f0c6df4c823c 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1034,7 +1034,6 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -1045,21 +1044,10 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - if ip_adapter_image_embeds is not None: - ip_hidden_states, temb = self.transformer.image_proj( - ip_adapter_image_embeds, timestep.to(dtype=latents.dtype) - ) - - image_prompt_embeds = {"ip_hidden_states": ip_hidden_states, "temb": temb} - - if self.joint_attention_kwargs is None: - self._joint_attention_kwargs = image_prompt_embeds - else: - self._joint_attention_kwargs.update(**image_prompt_embeds) - noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, + ip_adapter_image_embeds=ip_adapter_image_embeds, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs, @@ -1082,6 +1070,7 @@ def __call__( noise_pred_skip_layers = self.transformer( hidden_states=latent_model_input, timestep=timestep, + ip_adapter_image_embeds=ip_adapter_image_embeds, encoder_hidden_states=original_prompt_embeds, pooled_projections=original_pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs, From b882e1bdbe40f7b2f74c77954a65883a4f57e8b8 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Thu, 19 Dec 2024 15:07:04 +0000 Subject: [PATCH 40/44] ip_adapter_image_embeds go through joint_attention_kwargs --- src/diffusers/models/transformers/transformer_sd3.py | 12 +++--------- .../pipeline_stable_diffusion_3.py | 6 +++++- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 0898b2a7c481..415540ef7f6a 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -338,7 +338,6 @@ def forward( encoder_hidden_states: torch.FloatTensor = None, pooled_projections: torch.FloatTensor = None, timestep: torch.LongTensor = None, - ip_adapter_image_embeds: Optional[torch.FloatTensor] = None, block_controlnet_hidden_states: List = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, @@ -356,8 +355,6 @@ def forward( Embeddings projected from the embeddings of input conditions. timestep (`torch.LongTensor`): Used to indicate denoising step. - ip_adapter_image_embeds (`torch.FloatTensor`): - Image embeddings for IP-Adapter. block_controlnet_hidden_states (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): @@ -395,14 +392,11 @@ def forward( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - if ip_adapter_image_embeds is not None: + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep) - ip_embeds = {"ip_hidden_states": ip_hidden_states, "temb": ip_temb} - if joint_attention_kwargs is None: - joint_attention_kwargs = ip_embeds - else: - joint_attention_kwargs.update(**ip_embeds) + joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb) for index_block, block in enumerate(self.transformer_blocks): # Skip specified layers diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index f0c6df4c823c..29b5f627d241 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1033,6 +1033,11 @@ def __call__( self.do_classifier_free_guidance, ) + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1047,7 +1052,6 @@ def __call__( noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - ip_adapter_image_embeds=ip_adapter_image_embeds, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs, From 988447f863e2d41a39851f661bd9374a51da4f11 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Thu, 19 Dec 2024 15:37:03 +0000 Subject: [PATCH 41/44] Warning for sequential cpu offloading with image_encoder --- .../pipeline_stable_diffusion_3.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 29b5f627d241..c153d63f652e 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -182,8 +182,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle Image processor for IP Adapter. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" - _exclude_from_cpu_offload = ["image_encoder"] + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] @@ -760,6 +759,16 @@ def prepare_ip_adapter_image_embeds( image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) return image_embeds.to(device=device) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) From 98f4521a7b797e68b791c192d8b4eb7b94a94140 Mon Sep 17 00:00:00 2001 From: Daniel Regado <danielregado@gmail.com> Date: Thu, 19 Dec 2024 15:59:52 +0000 Subject: [PATCH 42/44] make style quality --- .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index c153d63f652e..36b82df4ca34 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -759,7 +759,7 @@ def prepare_ip_adapter_image_embeds( image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) return image_embeds.to(device=device) - + def enable_sequential_cpu_offload(self, *args, **kwargs): if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: logger.warning( From 18cd8e4f6ec525dacc44ed873c26490944dff341 Mon Sep 17 00:00:00 2001 From: YiYi Xu <yixu310@gmail.com> Date: Thu, 19 Dec 2024 11:28:20 -1000 Subject: [PATCH 43/44] Update src/diffusers/models/attention.py --- src/diffusers/models/attention.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5578f4980459..4d1dae879f11 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -209,10 +209,6 @@ def forward( encoder_hidden_states, emb=temb ) - # Empty dict if None is passed - if joint_attention_kwargs is None: - joint_attention_kwargs = {} - # Attention. attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, From 65b477fb167a21abf6df7fdf9a6d38cab26d15e3 Mon Sep 17 00:00:00 2001 From: YiYi Xu <yixu310@gmail.com> Date: Thu, 19 Dec 2024 14:07:24 -1000 Subject: [PATCH 44/44] Update src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py --- .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 36b82df4ca34..a53d786798ca 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1083,7 +1083,6 @@ def __call__( noise_pred_skip_layers = self.transformer( hidden_states=latent_model_input, timestep=timestep, - ip_adapter_image_embeds=ip_adapter_image_embeds, encoder_hidden_states=original_prompt_embeds, pooled_projections=original_pooled_prompt_embeds, joint_attention_kwargs=self.joint_attention_kwargs,