From 927553f32ad59f5914a1058ffea0259df6494f1a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 30 Oct 2024 22:28:09 +0200 Subject: [PATCH 01/45] initial commit --- .../community/pipeline_flux_rf_inversion.py | 933 ++++++++++++++++++ 1 file changed, 933 insertions(+) create mode 100644 examples/community/pipeline_flux_rf_inversion.py diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py new file mode 100644 index 000000000000..75c7813c7513 --- /dev/null +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -0,0 +1,933 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import FluxTransformer2DModel +from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. 
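+        >>> # Note: this file also exposes `RFInversionFluxPipeline.invert`, which inverts a real
+        >>> # image and stores the inverted latents on the pipeline for the editing call. A sketch,
+        >>> # where `init_image` is a placeholder PIL image (not defined in this example):
+        >>> # inverted_latents = pipe.invert(image=init_image, source_prompt="", gamma=0.5)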
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class RFInversionFluxPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for text-to-image generation. 
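+    This community pipeline implements RF-Inversion: a real image is first inverted into the model's
+    noise space with an interpolated velocity field (see the `invert` method below), and the inverted
+    latents are then denoised under a controllable eta schedule to edit the image.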
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} 
tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. 
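+            max_sequence_length (`int`, *optional*, defaults to 512):
+                Maximum sequence length to use with the `prompt` for the T5 text encoder (`tokenizer_2`).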
+ lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @torch.no_grad() + # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image + def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None): + image = self.image_processor.preprocess( + image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + resized = self.image_processor.postprocess(image=image, output_type="pil") + + if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5: + logger.warning( + "Your input images far exceed the default resolution of the underlying diffusion model. " + "The output images may contain severe artifacts! " + "Consider down-sampling the input using the `height` and `width` parameters" + ) + image = image.to(dtype) + + x0 = self.vae.encode(image.to(self.device)).latent_dist.sample() + x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor + x0 = x0.to(dtype) + return x0, resized + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." 
+ ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 
+ """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + def prepare_inverted_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + ): + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        latents = self._pack_latents(self.inverted_latents, batch_size, num_channels_latents, height, width)
+        return latents, latent_image_ids
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        eta_base: float = 1.0,  # base eta value
+        eta_trend: str = 'linear_decrease',  # constant, linear_increase, linear_decrease
+        start_step: int = 0,  # 0-based indexing, closed interval
+        end_step: int = 30,  # 0-based indexing, open interval
+        timesteps: List[int] = None,
+        use_shift_t_sampling: bool =True,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+                be used instead.
+            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 28):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 3.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2 of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt. 
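+            eta_base (`float`, *optional*, defaults to 1.0):
+                Base eta value for the controlled reverse ODE; it weights the velocity that pulls the sample
+                toward the inverted image latents (see `eta_trend`, `start_step` and `end_step`).
+            eta_trend (`str`, *optional*, defaults to `'linear_decrease'`):
+                How eta evolves between `start_step` and `end_step`; one of `constant`, `linear_increase` or
+                `linear_decrease`.
+            start_step (`int`, *optional*, defaults to 0):
+                First denoising step (0-based, inclusive) at which the eta schedule is applied.
+            end_step (`int`, *optional*, defaults to 30):
+                Denoising step (0-based, exclusive) at which the eta schedule stops being applied.
+            use_shift_t_sampling (`bool`, *optional*, defaults to `True`):
+                Whether to shift the timestep schedule based on the image sequence length (via `calculate_shift`).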
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+            images.
+        """
+
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            height,
+            width,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._joint_attention_kwargs = joint_attention_kwargs
+        self._interrupt = False
+
+        # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + # latents, latent_image_ids = self.prepare_latents( + # batch_size * num_images_per_prompt, + # num_channels_latents, + # height, + # width, + # prompt_embeds.dtype, + # device, + # generator, + # latents, + # ) + packed_inv_latents, latent_image_ids = self.prepare_inverted_latents( + self.inverted_latents.shape[0], + self.inverted_latents.shape[1], + self.inverted_latents.shape[2], + self.inverted_latents.shape[3], + prompt_embeds.dtype, + device, + generator, + ) + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) + + @torch.no_grad() + def invert( + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale = 0.0, + num_inversion_steps: int = 28, + gamma: float = 0.5, + use_shift_t_sampling: bool = False, + generator: Optional[torch.Generator] = None, + dtype: Optional[torch.dtype] = None, + ): + dtype = dtype or self.text_encoder.dtype + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + # 1. prepare image + img_latents, _ = self.encode_image(image, dtype=dtype) + + timesteps = get_schedule( + num_steps=num_steps, + image_seq_len=(1024 // 16) * (1024 // 16), # vae_scale_factor = 16 + shift=use_shift_t_sampling, # Set True for Flux-dev, False for Flux-schnell + )[::-1] # flipped for inversion + + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( + prompt=source_prompt + ) + latent_image_ids = self._prepare_latent_image_ids( + img_latents.shape[0], + img_latents.shape[2], + img_latents.shape[3], + img_latents.device, + dtype, + ) + packed_latents = self._pack_latents( + latents, + batch_size=latents.shape[0], + num_channels_latents=latents.shape[1], + height=latents.shape[2], + width=latents.shape[3], + ) + + target_noise = torch.randn(packed_latents.shape, device=packed_latents.device, dtype=torch.float32) + + guidance_vec = torch.full((packed_latents.shape[0],), source_guidance_scale, device=packed_latents.device, + dtype=packed_latents.dtype) + + # Image inversion with interpolated velocity field. 
t goes from 0.0 to 1.0 + with self.progress_bar(total=len(timesteps) - 1) as progress_bar: + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, + device=packed_latents.device) + + # Null text velocity + flux_velocity = pipeline.transformer( + hidden_states=packed_latents, + timestep=t_vec, + guidance=guidance_vec, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=None, + return_dict=pipeline, + )[0] + + # Prevents precision issues + packed_latents = packed_latents.to(torch.float32) + flux_velocity = flux_velocity.to(torch.float32) + + # Target noise velocity + target_noise_velocity = (target_noise - packed_latents) / (1.0 - t_curr) + + # interpolated velocity + interpolated_velocity = gamma * target_noise_velocity + (1 - gamma) * flux_velocity + + # one step Euler + packed_latents = packed_latents + (t_prev - t_curr) * interpolated_velocity + + packed_latents = packed_latents.to(dtype) + progress_bar.update() + + print("Mean Absolute Error", torch.mean(torch.abs(packed_latents - target_noise))) + + latents = self._unpack_latents( + packed_latents, + height=height, + width=width, + vae_scale_factor=self.vae_scale_factor, + ) + latents = latents.to(dtype) + self.inverted_latents = latents + self.image_latents = img_latents + return latents From 9382687a859403523374aac845114a4e4b7a904a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 31 Oct 2024 12:03:52 +0200 Subject: [PATCH 02/45] update denoising loop --- .../community/pipeline_flux_rf_inversion.py | 74 ++++++++++++++++--- 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 75c7813c7513..78c18cba589f 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -746,6 +746,17 @@ def __call__( device, generator, ) + + packed_img_latents = self._pack_latents( + self.image_latents, + batch_size=self.image_latents.shape[0], + num_channels_latents=self.image_latents.shape[1], + height=self.image_latents.shape[2], + width=self.image_latents.shape[3], + ) + target_img = packed_img_latents.clone().to(torch.float32) + + # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) image_seq_len = latents.shape[1] @@ -770,20 +781,22 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(latents.shape[0]) + guidance = guidance.expand(packed_inv_latents.shape[0]) else: guidance = None + eta_values = self.generate_eta_values(timesteps, start_step, end_step, eta_base, eta_trend) + # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue + for idx, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): + timestep = torch.full((packed_inv_latents.shape[0],), t_curr, dtype=packed_inv_latents.dtype, + device=packed_inv_latents.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) + #timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( + velocity = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, @@ -796,8 +809,22 @@ def __call__( )[0] # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + #latents_dtype = latents.dtype + # Prevents precision issues + packed_inv_latents = packed_inv_latents.to(torch.float32) + velocity = velocity.to(torch.float32) + + # Target image velocity + target_img_velocity = -(target_img - packed_inv_latents) / t_curr + + # interpolated velocity + eta = eta_values[idx] + interpolated_velocity = eta * target_img_velocity + (1 - eta) * velocity + latents = packed_inv_latents + (t_prev - t_curr) * interpolated_velocity + print( + f"X_{t_prev:.3f} = X_{t_curr:.3f} + {t_prev - t_curr:.3f} * ({eta:.3f} * target_img_velocity + {1 - eta:.3f} * flux_velocity)") + + #latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): @@ -837,6 +864,35 @@ def __call__( return FluxPipelineOutput(images=image) + def generate_eta_values( + self, + timesteps, + start_step, + end_step, + eta, + eta_trend, + ): + assert start_step < end_step and start_step >= 0 and end_step <= len( + timesteps), "Invalid start_step and end_step" + # timesteps are monotonically decreasing, from 1.0 to 0.0 + eta_values = [0.0] * (len(timesteps) - 1) + + if eta_trend == 'constant': + for i in range(start_step, end_step): + eta_values[i] = eta + elif eta_trend == 'linear_increase': + total_time = timesteps[start_step] - timesteps[end_step - 1] + for i in range(start_step, end_step): + eta_values[i] = eta * (timesteps[start_step] - timesteps[i]) / total_time + elif eta_trend == 'linear_decrease': + total_time = timesteps[start_step] - timesteps[end_step - 1] + for i in range(start_step, end_step): + eta_values[i] = eta * (timesteps[i] - timesteps[end_step - 1]) / total_time + else: + raise NotImplementedError(f"Unsupported eta_trend: {eta_trend}") + + return eta_values + @torch.no_grad() def invert( self, @@ -913,7 +969,7 @@ def invert( # interpolated velocity interpolated_velocity = gamma * target_noise_velocity + (1 - gamma) * flux_velocity - # one step Euler + # one-step Euler packed_latents = packed_latents + (t_prev - t_curr) * interpolated_velocity packed_latents = packed_latents.to(dtype) From ac497896cf08c1da7a36eeaff9a28b8866944527 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 31 Oct 2024 12:58:45 +0200 Subject: [PATCH 03/45] fix scheduling --- .../community/pipeline_flux_rf_inversion.py | 62 +++++++++++++------ 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 78c18cba589f..da204df38b6c 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ 
-759,7 +759,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.shape[1] + image_seq_len = packed_inv_latents.shape[1] mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -790,11 +790,10 @@ def __call__( # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for idx, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): - timestep = torch.full((packed_inv_latents.shape[0],), t_curr, dtype=packed_inv_latents.dtype, - device=packed_inv_latents.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML #timestep = t.expand(latents.shape[0]).to(latents.dtype) + timestep = t_curr.expand(packed_inv_latents.shape[0]).to(packed_inv_latents.dtype) velocity = self.transformer( hidden_states=latents, @@ -908,14 +907,33 @@ def invert( dtype = dtype or self.text_encoder.dtype height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + device = self._execution_device + # 1. prepare image img_latents, _ = self.encode_image(image, dtype=dtype) - timesteps = get_schedule( - num_steps=num_steps, - image_seq_len=(1024 // 16) * (1024 // 16), # vae_scale_factor = 16 - shift=use_shift_t_sampling, # Set True for Flux-dev, False for Flux-schnell - )[::-1] # flipped for inversion + # 2. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps) + image_seq_len = img_latents.shape[1] + if use_shift_t_sampling: + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + else: + mu = None + timesteps, num_inversion_steps = retrieve_timesteps( + self.scheduler, + num_inversion_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + timesteps = timesteps[::-1] # flip for inversion prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( prompt=source_prompt @@ -924,7 +942,7 @@ def invert( img_latents.shape[0], img_latents.shape[2], img_latents.shape[3], - img_latents.device, + device, dtype, ) packed_latents = self._pack_latents( @@ -935,22 +953,28 @@ def invert( width=latents.shape[3], ) - target_noise = torch.randn(packed_latents.shape, device=packed_latents.device, dtype=torch.float32) + target_noise = torch.randn(packed_latents.shape, device=device, dtype=torch.float32) - guidance_vec = torch.full((packed_latents.shape[0],), source_guidance_scale, device=packed_latents.device, - dtype=packed_latents.dtype) + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], source_guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(packed_latents.shape[0]) + else: + guidance = None # Image inversion with interpolated velocity field. 
t goes from 0.0 to 1.0 with self.progress_bar(total=len(timesteps) - 1) as progress_bar: for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): - t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, - device=packed_latents.device) + + # t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, + # device=device) + timestep = t_curr.expand(packed_latents.shape[0]).to(packed_latents.dtype) # Null text velocity - flux_velocity = pipeline.transformer( + velocity = pipeline.transformer( hidden_states=packed_latents, - timestep=t_vec, - guidance=guidance_vec, + timestep=timestep, + guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, @@ -961,13 +985,13 @@ def invert( # Prevents precision issues packed_latents = packed_latents.to(torch.float32) - flux_velocity = flux_velocity.to(torch.float32) + velocity = velocity.to(torch.float32) # Target noise velocity target_noise_velocity = (target_noise - packed_latents) / (1.0 - t_curr) # interpolated velocity - interpolated_velocity = gamma * target_noise_velocity + (1 - gamma) * flux_velocity + interpolated_velocity = gamma * target_noise_velocity + (1 - gamma) * velocity # one-step Euler packed_latents = packed_latents + (t_prev - t_curr) * interpolated_velocity From bd9b7f1913ae4ee7b9b391e97ecb6cb3475c9a6b Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 31 Oct 2024 12:25:27 +0000 Subject: [PATCH 04/45] style --- .../community/pipeline_flux_rf_inversion.py | 65 +++++++++---------- 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index da204df38b6c..c3bf7a1def16 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -394,6 +394,7 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode=" x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor x0 = x0.to(dtype) return x0, resized + def check_inputs( self, prompt, @@ -552,7 +553,6 @@ def prepare_inverted_latents( device, generator, ): - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) if isinstance(generator, list) and len(generator) != batch_size: @@ -590,11 +590,11 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 28, eta_base: float = 1.0, # base eta value - eta_trend: str = 'linear_decrease', # constant, linear_increase, linear_decrease + eta_trend: str = "linear_decrease", # constant, linear_increase, linear_decrease start_step: int = 0, # 0-based indexing, closed interval end_step: int = 30, # 0-based indexing, open interval timesteps: List[int] = None, - use_shift_t_sampling: bool =True, + use_shift_t_sampling: bool = True, guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -756,7 +756,6 @@ def __call__( ) target_img = packed_img_latents.clone().to(torch.float32) - # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) image_seq_len = packed_inv_latents.shape[1] @@ -790,9 +789,8 @@ def __call__( # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for idx, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - #timestep = t.expand(latents.shape[0]).to(latents.dtype) + # timestep = t.expand(latents.shape[0]).to(latents.dtype) timestep = t_curr.expand(packed_inv_latents.shape[0]).to(packed_inv_latents.dtype) velocity = self.transformer( @@ -808,7 +806,7 @@ def __call__( )[0] # compute the previous noisy sample x_t -> x_t-1 - #latents_dtype = latents.dtype + # latents_dtype = latents.dtype # Prevents precision issues packed_inv_latents = packed_inv_latents.to(torch.float32) velocity = velocity.to(torch.float32) @@ -821,9 +819,10 @@ def __call__( interpolated_velocity = eta * target_img_velocity + (1 - eta) * velocity latents = packed_inv_latents + (t_prev - t_curr) * interpolated_velocity print( - f"X_{t_prev:.3f} = X_{t_curr:.3f} + {t_prev - t_curr:.3f} * ({eta:.3f} * target_img_velocity + {1 - eta:.3f} * flux_velocity)") + f"X_{t_prev:.3f} = X_{t_curr:.3f} + {t_prev - t_curr:.3f} * ({eta:.3f} * target_img_velocity + {1 - eta:.3f} * flux_velocity)" + ) - #latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): @@ -864,26 +863,27 @@ def __call__( return FluxPipelineOutput(images=image) def generate_eta_values( - self, - timesteps, - start_step, - end_step, - eta, - eta_trend, + self, + timesteps, + start_step, + end_step, + eta, + eta_trend, ): - assert start_step < end_step and start_step >= 0 and end_step <= len( - timesteps), "Invalid start_step and end_step" + assert ( + start_step < end_step and start_step >= 0 and end_step <= len(timesteps) + ), "Invalid start_step and end_step" # timesteps are monotonically decreasing, from 1.0 to 0.0 eta_values = [0.0] * (len(timesteps) - 1) - if eta_trend == 'constant': + if eta_trend == "constant": for i in range(start_step, end_step): eta_values[i] = eta - elif eta_trend == 'linear_increase': + elif eta_trend == "linear_increase": total_time = timesteps[start_step] - timesteps[end_step - 1] for i in range(start_step, end_step): eta_values[i] = eta * (timesteps[start_step] - timesteps[i]) / total_time - elif eta_trend == 'linear_decrease': + elif eta_trend == "linear_decrease": total_time = timesteps[start_step] - timesteps[end_step - 1] for i in range(start_step, end_step): eta_values[i] = eta * (timesteps[i] - timesteps[end_step - 1]) / total_time @@ -894,15 +894,15 @@ def generate_eta_values( @torch.no_grad() def invert( - self, - image: PipelineImageInput, - source_prompt: str = "", - source_guidance_scale = 0.0, - num_inversion_steps: int = 28, - gamma: float = 0.5, - use_shift_t_sampling: bool = False, - generator: Optional[torch.Generator] = None, - dtype: Optional[torch.dtype] = None, + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale=0.0, + num_inversion_steps: int = 28, + gamma: float = 0.5, + use_shift_t_sampling: bool = False, + generator: Optional[torch.Generator] = None, + dtype: Optional[torch.dtype] = None, ): dtype = dtype or self.text_encoder.dtype height = height or self.default_sample_size * self.vae_scale_factor @@ -933,11 +933,9 @@ def invert( mu=mu, ) - timesteps = timesteps[::-1] # flip for inversion + timesteps = timesteps[::-1] # flip for inversion - prompt_embeds, pooled_prompt_embeds, 
text_ids = pipeline.encode_prompt( - prompt=source_prompt - ) + prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt=source_prompt) latent_image_ids = self._prepare_latent_image_ids( img_latents.shape[0], img_latents.shape[2], @@ -965,7 +963,6 @@ def invert( # Image inversion with interpolated velocity field. t goes from 0.0 to 1.0 with self.progress_bar(total=len(timesteps) - 1) as progress_bar: for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): - # t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, # device=device) timestep = t_curr.expand(packed_latents.shape[0]).to(packed_latents.dtype) From 8358ecd7de761acbac0b6d5b2937d038efc84c16 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 31 Oct 2024 14:29:46 +0200 Subject: [PATCH 05/45] fix import --- examples/community/pipeline_flux_rf_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index c3bf7a1def16..ee6c64d936d0 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import FluxTransformer2DModel from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput From 0d28b0e6078db8ab4d98f7a83045ed31112d1571 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 31 Oct 2024 14:43:08 +0200 Subject: [PATCH 06/45] fixes --- .../community/pipeline_flux_rf_inversion.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index ee6c64d936d0..3ecec1d4d8a5 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -19,7 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast -from diffusers.image_processor import VaeImageProcessor +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import FluxTransformer2DModel @@ -34,7 +34,7 @@ scale_lora_layers, unscale_lora_layers, ) - +from diffusers.utils.torch_utils import randn_tensor if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -788,13 +788,13 @@ def __call__( # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: - for idx, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): + for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): # broadcast to batch dimension in a way that's compatible with ONNX/Core ML # timestep = t.expand(latents.shape[0]).to(latents.dtype) timestep = t_curr.expand(packed_inv_latents.shape[0]).to(packed_inv_latents.dtype) velocity = self.transformer( - hidden_states=latents, + hidden_states=packed_inv_latents, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, @@ -815,7 +815,7 @@ def __call__( target_img_velocity = -(target_img - packed_inv_latents) / t_curr # interpolated velocity - eta = eta_values[idx] + eta = eta_values[i] interpolated_velocity = eta * target_img_velocity + (1 - eta) * velocity latents = packed_inv_latents + (t_prev - t_curr) * interpolated_velocity print( @@ -824,16 +824,16 @@ def __call__( # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - if latents.dtype != latents_dtype: - if torch.backends.mps.is_available(): - # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - latents = latents.to(latents_dtype) + # if latents.dtype != latents_dtype: + # if torch.backends.mps.is_available(): + # # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + # latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t_curr, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) @@ -901,7 +901,8 @@ def invert( num_inversion_steps: int = 28, gamma: float = 0.5, use_shift_t_sampling: bool = False, - generator: Optional[torch.Generator] = None, + height: Optional[int] = None, + width: Optional[int] = None, dtype: Optional[torch.dtype] = None, ): dtype = dtype or self.text_encoder.dtype @@ -935,7 +936,7 @@ def invert( timesteps = timesteps[::-1] # flip for inversion - prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt=source_prompt) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(prompt=source_prompt) latent_image_ids = self._prepare_latent_image_ids( img_latents.shape[0], img_latents.shape[2], @@ -968,7 +969,7 @@ def invert( timestep = t_curr.expand(packed_latents.shape[0]).to(packed_latents.dtype) # Null text velocity - velocity = pipeline.transformer( + velocity = self.transformer( hidden_states=packed_latents, timestep=timestep, guidance=guidance, @@ -977,7 +978,7 @@ def invert( txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=None, - return_dict=pipeline, + return_dict=False, )[0] # Prevents precision issues From f5df0304f50feee3620f2f6dfbda7b310d93dad2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 31 Oct 2024 14:45:18 +0200 Subject: [PATCH 07/45] fixes --- examples/community/pipeline_flux_rf_inversion.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 3ecec1d4d8a5..20906250ea98 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ 
b/examples/community/pipeline_flux_rf_inversion.py @@ -936,7 +936,7 @@ def invert( timesteps = timesteps[::-1] # flip for inversion - prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(prompt=source_prompt) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(prompt=source_prompt, prompt_2=source_prompt) latent_image_ids = self._prepare_latent_image_ids( img_latents.shape[0], img_latents.shape[2], @@ -945,11 +945,11 @@ def invert( dtype, ) packed_latents = self._pack_latents( - latents, - batch_size=latents.shape[0], - num_channels_latents=latents.shape[1], - height=latents.shape[2], - width=latents.shape[3], + img_latents, + batch_size=img_latents.shape[0], + num_channels_latents=img_latents.shape[1], + height=img_latents.shape[2], + width=img_latents.shape[3], ) target_noise = torch.randn(packed_latents.shape, device=device, dtype=torch.float32) From ad66f17d2d65b6155f0b0a342996bcdd776bdbc6 Mon Sep 17 00:00:00 2001 From: Linoy Date: Thu, 31 Oct 2024 12:48:42 +0000 Subject: [PATCH 08/45] style --- examples/community/pipeline_flux_rf_inversion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 20906250ea98..0d19b29f1bae 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -36,6 +36,7 @@ ) from diffusers.utils.torch_utils import randn_tensor + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -936,7 +937,9 @@ def invert( timesteps = timesteps[::-1] # flip for inversion - prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(prompt=source_prompt, prompt_2=source_prompt) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=source_prompt, prompt_2=source_prompt + ) latent_image_ids = self._prepare_latent_image_ids( img_latents.shape[0], img_latents.shape[2], From 757f98bd93a616e586f3912fa5047339d908a09c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 1 Nov 2024 09:04:48 +0200 Subject: [PATCH 09/45] fixes --- examples/community/pipeline_flux_rf_inversion.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 20906250ea98..934cd8b17b7d 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -934,9 +934,11 @@ def invert( mu=mu, ) - timesteps = timesteps[::-1] # flip for inversion + timesteps = reversed(timesteps[-num_inversion_steps:]) # flip for inversion - prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(prompt=source_prompt, prompt_2=source_prompt) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=source_prompt, prompt_2=source_prompt + ) latent_image_ids = self._prepare_latent_image_ids( img_latents.shape[0], img_latents.shape[2], From d2c4f7d97d1b3a152aa5802792cafc45e051059f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 3 Nov 2024 15:37:43 +0200 Subject: [PATCH 10/45] change invert --- .../community/pipeline_flux_rf_inversion.py | 176 +++++++++--------- 1 file changed, 87 insertions(+), 89 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index c084e52441b3..f764ee81ff3b 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ 
-136,7 +136,6 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps - class RFInversionFluxPipeline( DiffusionPipeline, FluxLoraLoaderMixin, @@ -565,6 +564,18 @@ def prepare_inverted_latents( latents = self._pack_latents(self.inverted_latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + @property def guidance_scale(self): return self._guidance_scale @@ -901,114 +912,101 @@ def invert( source_guidance_scale=0.0, num_inversion_steps: int = 28, gamma: float = 0.5, - use_shift_t_sampling: bool = False, height: Optional[int] = None, width: Optional[int] = None, + timesteps: List[int] = None, dtype: Optional[torch.dtype] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, ): dtype = dtype or self.text_encoder.dtype + batch_size = 1 + num_channels_latents = self.transformer.config.in_channels // 4 height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor device = self._execution_device # 1. prepare image - img_latents, _ = self.encode_image(image, dtype=dtype) - - # 2. Prepare timesteps + image_latents, _ = self.encode_image(image, dtype=dtype) + _, latent_image_ids = self.prepare_latents(batch_size, + num_channels_latents, + height, width, + dtype, device, generator) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + self.image_latents = image_latents.clone() + + # 2. 
prepare timesteps sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps) - image_seq_len = img_latents.shape[1] - if use_shift_t_sampling: - mu = calculate_shift( - image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, - ) - else: - mu = None - timesteps, num_inversion_steps = retrieve_timesteps( + image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inversion_steps, device, - sigmas=sigmas, + timesteps, + sigmas, mu=mu, ) + timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) - timesteps = reversed(timesteps[-num_inversion_steps:]) # flip for inversion - - prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( - prompt=source_prompt, prompt_2=source_prompt - ) - latent_image_ids = self._prepare_latent_image_ids( - img_latents.shape[0], - img_latents.shape[2], - img_latents.shape[3], - device, - dtype, - ) - packed_latents = self._pack_latents( - img_latents, - batch_size=img_latents.shape[0], - num_channels_latents=img_latents.shape[1], - height=img_latents.shape[2], - width=img_latents.shape[3], + # 3. prepare text embeddings + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=source_prompt, + prompt_2=source_prompt, + device=device, ) - - target_noise = torch.randn(packed_latents.shape, device=device, dtype=torch.float32) - - # handle guidance + # 4. handle guidance if self.transformer.config.guidance_embeds: guidance = torch.full([1], source_guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(packed_latents.shape[0]) else: guidance = None - # Image inversion with interpolated velocity field. 
t goes from 0.0 to 1.0 - with self.progress_bar(total=len(timesteps) - 1) as progress_bar: - for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): - # t_vec = torch.full((packed_latents.shape[0],), t_curr, dtype=packed_latents.dtype, - # device=device) - timestep = t_curr.expand(packed_latents.shape[0]).to(packed_latents.dtype) - - # Null text velocity - velocity = self.transformer( - hidden_states=packed_latents, - timestep=timestep, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=None, - return_dict=False, - )[0] - - # Prevents precision issues - packed_latents = packed_latents.to(torch.float32) - velocity = velocity.to(torch.float32) - - # Target noise velocity - target_noise_velocity = (target_noise - packed_latents) / (1.0 - t_curr) - - # interpolated velocity - interpolated_velocity = gamma * target_noise_velocity + (1 - gamma) * velocity - - # one-step Euler - packed_latents = packed_latents + (t_prev - t_curr) * interpolated_velocity - - packed_latents = packed_latents.to(dtype) - progress_bar.update() - - print("Mean Absolute Error", torch.mean(torch.abs(packed_latents - target_noise))) - - latents = self._unpack_latents( - packed_latents, - height=height, - width=width, - vae_scale_factor=self.vae_scale_factor, - ) - latents = latents.to(dtype) - self.inverted_latents = latents - self.image_latents = img_latents - return latents + # if num_inference_steps < 1: + # raise ValueError( + # f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + # f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + # ) + # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt + Y_t = image_latents + y_1 = torch.randn_like(Y_t) + N = len(sigmas) + + for i in range(N - 1): # enumerate(timesteps): + t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) + timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) + # get the unconditional vector field + + u_t_i = self.transformer( + hidden_states=Y_t, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # get the conditional vector field + u_t_i_cond = (y_1 - Y_t) / (1 - t_i) + + # controlled vector field + # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt + u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i) + Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i+1]) + + self.inverted_latents = Y_t + + return latent_image_ids From ff75de8d6faaf4d5dce6a5a320a4784b5b1a6481 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 3 Nov 2024 16:00:21 +0200 Subject: [PATCH 11/45] change denoising & check inputs --- .../community/pipeline_flux_rf_inversion.py | 102 +++++++++--------- 1 file changed, 49 insertions(+), 53 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index f764ee81ff3b..3fdb5a141ac9 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -401,10 +401,13 @@ def check_inputs( prompt_2, height, width, + start_timestep, + stop_timestep, prompt_embeds=None, pooled_prompt_embeds=None, 
callback_on_step_end_tensor_inputs=None, max_sequence_length=None, + ): if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError( @@ -445,6 +448,10 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + # check start_timestep and stop_timestep + if start_timestep < 0 or start_timestep > stop_timestep: + raise ValueError(f"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}") + @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height, width, 3) @@ -600,11 +607,12 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, + strength: float = 0.6, + eta: float = 1.0, + gamma: float = 1.0, + start_timestep: int = 0, + stop_timestep: int = 6, num_inference_steps: int = 28, - eta_base: float = 1.0, # base eta value - eta_trend: str = "linear_decrease", # constant, linear_increase, linear_decrease - start_step: int = 0, # 0-based indexing, closed interval - end_step: int = 30, # 0-based indexing, open interval timesteps: List[int] = None, use_shift_t_sampling: bool = True, guidance_scale: float = 3.5, @@ -703,6 +711,7 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, + ) self._guidance_scale = guidance_scale @@ -739,6 +748,10 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.inverted_latents + latent_image_ids = self.latent_image_ids + image_latents = self.image_latents + # latents, latent_image_ids = self.prepare_latents( # batch_size * num_images_per_prompt, # num_channels_latents, @@ -749,28 +762,10 @@ def __call__( # generator, # latents, # ) - packed_inv_latents, latent_image_ids = self.prepare_inverted_latents( - self.inverted_latents.shape[0], - self.inverted_latents.shape[1], - self.inverted_latents.shape[2], - self.inverted_latents.shape[3], - prompt_embeds.dtype, - device, - generator, - ) - - packed_img_latents = self._pack_latents( - self.image_latents, - batch_size=self.image_latents.shape[0], - num_channels_latents=self.image_latents.shape[1], - height=self.image_latents.shape[2], - width=self.image_latents.shape[3], - ) - target_img = packed_img_latents.clone().to(torch.float32) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = packed_inv_latents.shape[1] + image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -792,21 +787,25 @@ def __call__( # handle guidance if self.transformer.config.guidance_embeds: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) - guidance = guidance.expand(packed_inv_latents.shape[0]) + guidance = guidance.expand(latents.shape[0]) else: guidance = None - eta_values = self.generate_eta_values(timesteps, start_step, end_step, eta_base, eta_trend) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): + y_0 = image_latents.clone() + for i, t in enumerate(timesteps): + t_i = 1 - t / 1000 # torch.tensor((i+1) / (len(timesteps)-1), device=device) + # print(t_i, t) + dt = torch.tensor(1 / (len(timesteps) - 1), device=device) + if self.interrupt: + continue + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - # timestep = t.expand(latents.shape[0]).to(latents.dtype) - timestep = t_curr.expand(packed_inv_latents.shape[0]).to(packed_inv_latents.dtype) + timestep = t.expand(latents.shape[0]).to(latents.dtype) - velocity = self.transformer( - hidden_states=packed_inv_latents, + v_t = -self.transformer( + hidden_states=latents, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, @@ -817,35 +816,31 @@ def __call__( return_dict=False, )[0] - # compute the previous noisy sample x_t -> x_t-1 - # latents_dtype = latents.dtype - # Prevents precision issues - packed_inv_latents = packed_inv_latents.to(torch.float32) - velocity = velocity.to(torch.float32) - - # Target image velocity - target_img_velocity = -(target_img - packed_inv_latents) / t_curr + v_t_cond = (y_0 - latents) / (1 - t_i) + eta_t = eta if start_timestep <= i < stop_timestep else 0.0 + if start_timestep <= i < stop_timestep: + # controlled vector field + v_hat_t = v_t + eta * (v_t_cond - v_t) + # v_hat_t = ((1 - t_i - eta_t) * latents + eta_t * t_i * y_0) / (t_i*(1 - t_i)) + 2*(1-t_i)*(1 - eta_t) /t_i * v_t - # interpolated velocity - eta = eta_values[i] - interpolated_velocity = eta * target_img_velocity + (1 - eta) * velocity - latents = packed_inv_latents + (t_prev - t_curr) * interpolated_velocity - print( - f"X_{t_prev:.3f} = X_{t_curr:.3f} + {t_prev - t_curr:.3f} * ({eta:.3f} * target_img_velocity + {1 - eta:.3f} * flux_velocity)" - ) + else: + v_hat_t = v_t + # SDE Eq: 17 - # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) - # if latents.dtype != latents_dtype: - # if torch.backends.mps.is_available(): - # # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 - # latents = latents.to(latents_dtype) + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t_curr, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) @@ -1008,5 +1003,6 @@ def invert( Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i+1]) self.inverted_latents = Y_t + self.latent_image_ids = latent_image_ids - return latent_image_ids + return self.image_latents, Y_t, latent_image_ids From f80ca4271cf90d96f175a89842304484e3f4a65c Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 3 Nov 2024 22:24:44 +0200 Subject: [PATCH 12/45] shape & timesteps fixes --- examples/community/pipeline_flux_rf_inversion.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 3fdb5a141ac9..839effe5ed79 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -577,11 +577,12 @@ def get_timesteps(self, num_inference_steps, strength, device): init_timestep = min(num_inference_steps * strength, num_inference_steps) t_start = int(max(num_inference_steps - init_timestep, 0)) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + sigmas = self.scheduler.sigmas[t_start * self.scheduler.order :] if hasattr(self.scheduler, "set_begin_index"): self.scheduler.set_begin_index(t_start * self.scheduler.order) - return timesteps, num_inference_steps - t_start + return timesteps, sigmas, num_inference_steps - t_start @property def guidance_scale(self): @@ -707,6 +708,8 @@ def __call__( prompt_2, height, width, + start_timestep, + stop_timestep, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, @@ -907,25 +910,32 @@ def invert( source_guidance_scale=0.0, num_inversion_steps: int = 28, gamma: float = 0.5, + strength: float = 0.6, height: Optional[int] = None, width: Optional[int] = None, timesteps: List[int] = None, dtype: Optional[torch.dtype] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): dtype = dtype or self.text_encoder.dtype batch_size = 1 + self._joint_attention_kwargs = joint_attention_kwargs num_channels_latents = self.transformer.config.in_channels // 4 + height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor device = self._execution_device # 1. 
prepare image - image_latents, _ = self.encode_image(image, dtype=dtype) + image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype) _, latent_image_ids = self.prepare_latents(batch_size, num_channels_latents, height, width, dtype, device, generator) + + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) self.image_latents = image_latents.clone() From d964300fccf6316407b2457f8198e396a1eb8232 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 4 Nov 2024 00:10:51 +0200 Subject: [PATCH 13/45] timesteps fixes --- examples/community/pipeline_flux_rf_inversion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 839effe5ed79..773ffc040bae 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -784,6 +784,7 @@ def __call__( sigmas, mu=mu, ) + timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -797,6 +798,7 @@ def __call__( # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: y_0 = image_latents.clone() + print(f"y_0 shape: {y_0.shape}") for i, t in enumerate(timesteps): t_i = 1 - t / 1000 # torch.tensor((i+1) / (len(timesteps)-1), device=device) # print(t_i, t) From 5834e7b25f95cc1fac9627f3eee1a6128a4c0e27 Mon Sep 17 00:00:00 2001 From: Linoy Date: Sun, 3 Nov 2024 22:16:20 +0000 Subject: [PATCH 14/45] style --- examples/community/pipeline_flux_rf_inversion.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 773ffc040bae..06ef5588d08b 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -136,6 +136,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + class RFInversionFluxPipeline( DiffusionPipeline, FluxLoraLoaderMixin, @@ -407,7 +408,6 @@ def check_inputs( pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, - ): if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: raise ValueError( @@ -714,7 +714,6 @@ def __call__( pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, - ) self._guidance_scale = guidance_scale @@ -931,10 +930,9 @@ def invert( # 1. 
prepare image image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype) - _, latent_image_ids = self.prepare_latents(batch_size, - num_channels_latents, - height, width, - dtype, device, generator) + _, latent_image_ids = self.prepare_latents( + batch_size, num_channels_latents, height, width, dtype, device, generator + ) height = int(height) // self.vae_scale_factor width = int(width) // self.vae_scale_factor @@ -1012,7 +1010,7 @@ def invert( # controlled vector field # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i) - Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i+1]) + Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1]) self.inverted_latents = Y_t self.latent_image_ids = latent_image_ids From 910bce7176521d5ae342bbbe1f8c3b1a4b558925 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 4 Nov 2024 10:45:55 +0200 Subject: [PATCH 15/45] remove redundancies --- .../community/pipeline_flux_rf_inversion.py | 40 +------------------ 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 773ffc040bae..ebfecd38c172 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -615,7 +615,6 @@ def __call__( stop_timestep: int = 6, num_inference_steps: int = 28, timesteps: List[int] = None, - use_shift_t_sampling: bool = True, guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -874,36 +873,6 @@ def __call__( return FluxPipelineOutput(images=image) - def generate_eta_values( - self, - timesteps, - start_step, - end_step, - eta, - eta_trend, - ): - assert ( - start_step < end_step and start_step >= 0 and end_step <= len(timesteps) - ), "Invalid start_step and end_step" - # timesteps are monotonically decreasing, from 1.0 to 0.0 - eta_values = [0.0] * (len(timesteps) - 1) - - if eta_trend == "constant": - for i in range(start_step, end_step): - eta_values[i] = eta - elif eta_trend == "linear_increase": - total_time = timesteps[start_step] - timesteps[end_step - 1] - for i in range(start_step, end_step): - eta_values[i] = eta * (timesteps[start_step] - timesteps[i]) / total_time - elif eta_trend == "linear_decrease": - total_time = timesteps[start_step] - timesteps[end_step - 1] - for i in range(start_step, end_step): - eta_values[i] = eta * (timesteps[i] - timesteps[end_step - 1]) / total_time - else: - raise NotImplementedError(f"Unsupported eta_trend: {eta_trend}") - - return eta_values - @torch.no_grad() def invert( self, @@ -977,13 +946,6 @@ def invert( else: guidance = None - # if num_inference_steps < 1: - # raise ValueError( - # f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" - # f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." 
- # ) - # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt Y_t = image_latents y_1 = torch.randn_like(Y_t) @@ -992,8 +954,8 @@ def invert( for i in range(N - 1): # enumerate(timesteps): t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) - # get the unconditional vector field + # get the unconditional vector field u_t_i = self.transformer( hidden_states=Y_t, timestep=timestep, From 156e611669dc145be8e5c5d027fb07020f7d10a4 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 4 Nov 2024 11:18:24 +0200 Subject: [PATCH 16/45] small changes --- .../community/pipeline_flux_rf_inversion.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index bdc7c3416d98..08b70c3428be 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -795,12 +795,13 @@ def __call__( # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: + y_0 = image_latents.clone() - print(f"y_0 shape: {y_0.shape}") for i, t in enumerate(timesteps): - t_i = 1 - t / 1000 # torch.tensor((i+1) / (len(timesteps)-1), device=device) - # print(t_i, t) + + t_i = 1 - t / 1000 dt = torch.tensor(1 / (len(timesteps) - 1), device=device) + if self.interrupt: continue @@ -824,13 +825,11 @@ def __call__( if start_timestep <= i < stop_timestep: # controlled vector field v_hat_t = v_t + eta * (v_t_cond - v_t) - # v_hat_t = ((1 - t_i - eta_t) * latents + eta_t * t_i * y_0) / (t_i*(1 - t_i)) + 2*(1-t_i)*(1 - eta_t) /t_i * v_t else: v_hat_t = v_t # SDE Eq: 17 - # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) @@ -918,7 +917,7 @@ def invert( self.scheduler.config.base_shift, self.scheduler.config.max_shift, ) - timesteps, num_inference_steps = retrieve_timesteps( + timesteps, num_inversion_steps = retrieve_timesteps( self.scheduler, num_inversion_steps, device, @@ -926,7 +925,7 @@ def invert( sigmas, mu=mu, ) - timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength, device) # 3. 
prepare text embeddings ( @@ -947,10 +946,9 @@ def invert( # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt Y_t = image_latents y_1 = torch.randn_like(Y_t) - N = len(sigmas) - for i in range(N - 1): # enumerate(timesteps): - t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) + for i in range(num_inversion_steps - 1): + t_i = torch.tensor(i / (num_inversion_steps), dtype=Y_t.dtype, device=device) timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) # get the unconditional vector field From b32a08cf83609a4afa974ff7e0898ec4ab9b561b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 4 Nov 2024 14:16:45 +0200 Subject: [PATCH 17/45] update documentation a bit --- .../community/pipeline_flux_rf_inversion.py | 37 +++++++++++++++++-- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 08b70c3428be..debfe5de7420 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -572,7 +572,7 @@ def prepare_inverted_latents( return latents, latent_image_ids # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -610,7 +610,6 @@ def __call__( width: Optional[int] = None, strength: float = 0.6, eta: float = 1.0, - gamma: float = 1.0, start_timestep: int = 0, stop_timestep: int = 6, num_inference_steps: int = 28, @@ -642,6 +641,9 @@ def __call__( The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. + eta (`float`, *optional*, defaults to 1.0): + The controller guidance, balancing faithfulness & editability: + higher eta - better faithfulness, less editability. For more significant edits, lower the value of eta. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -782,7 +784,7 @@ def __call__( sigmas, mu=mu, ) - timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -887,6 +889,33 @@ def invert( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): + r""" + Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792 + Args: + image (`PipelineImageInput`): + Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect + ratio. + source_prompt (`str` or `List[str]`, *optional*, defaults to an empty prompt as done in the original paper): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. 
+ source_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2 of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). For this algorithm, it's better to keep it 0. + num_inversion_steps (`int`, *optional*, defaults to 28): + The number of discretization steps. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + gamma (`float`, *optional*, defaults to 0.5): + The controller guidance for the forward ODE, balancing faithfulness & editability: + higher gamma - better faithfulness, less editability. For more significant edits, lower the value of gamma. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + """ dtype = dtype or self.text_encoder.dtype batch_size = 1 self._joint_attention_kwargs = joint_attention_kwargs From 87146476c7be56ba50ccecc847b4c4a2601459f0 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 4 Nov 2024 14:24:43 +0200 Subject: [PATCH 18/45] update documentation a bit --- .../community/pipeline_flux_rf_inversion.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index debfe5de7420..d818e5926a57 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -55,11 +55,20 @@ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> prompt = "A cat holding a sign that says hello world" - >>> # Depending on the variant being used, the pipeline call will slightly vary. - >>> # Refer to the pipeline documentation for more details. - >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] - >>> image.save("flux.png") + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg" + >>> image = download_image(img_url) + + >>> _,__,___ = pipe.invert(image=image, num_inversion_steps=28) + + >>> edited_image = pipe( + ... prompt="a portrait of a tiger", + ... 
).images[0] ``` """ From c64a0e72610b5745b0d0dd194b12f21f91fad5d2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 4 Nov 2024 15:31:06 +0200 Subject: [PATCH 19/45] update documentation a bit --- examples/community/pipeline_flux_rf_inversion.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index d818e5926a57..75a1124e3d90 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -764,17 +764,6 @@ def __call__( latent_image_ids = self.latent_image_ids image_latents = self.image_latents - # latents, latent_image_ids = self.prepare_latents( - # batch_size * num_images_per_prompt, - # num_channels_latents, - # height, - # width, - # prompt_embeds.dtype, - # device, - # generator, - # latents, - # ) - # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) image_seq_len = latents.shape[1] From 142aef569c1d0efef405b3178924e23de2661114 Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 4 Nov 2024 13:32:22 +0000 Subject: [PATCH 20/45] style --- examples/community/pipeline_flux_rf_inversion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 75a1124e3d90..2a757c97e8c2 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -55,7 +55,7 @@ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - + >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") @@ -795,10 +795,8 @@ def __call__( # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: - y_0 = image_latents.clone() for i, t in enumerate(timesteps): - t_i = 1 - t / 1000 dt = torch.tensor(1 / (len(timesteps) - 1), device=device) From 9d8b37c10c8f02d15da708bf7c1a7cddc6d1bc0d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Tue, 5 Nov 2024 15:30:56 +0200 Subject: [PATCH 21/45] change strength param, remove redundancies --- .../community/pipeline_flux_rf_inversion.py | 104 ++++++++---------- 1 file changed, 43 insertions(+), 61 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 75a1124e3d90..4422d03fb7c6 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -559,31 +559,10 @@ def prepare_latents( return latents, latent_image_ids - def prepare_inverted_latents( - self, - batch_size, - num_channels_latents, - height, - width, - dtype, - device, - generator, - ): - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - - latents = self._pack_latents(self.inverted_latents, batch_size, num_channels_latents, height, width) - return latents, latent_image_ids - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength): + def get_timesteps(self, num_inference_steps, timestep_offset): # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * strength, num_inference_steps) + init_timestep = min(num_inference_steps * timestep_offset, num_inference_steps) t_start = int(max(num_inference_steps - init_timestep, 0)) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] @@ -612,29 +591,29 @@ def interrupt(self): @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - strength: float = 0.6, - eta: float = 1.0, - start_timestep: int = 0, - stop_timestep: int = 6, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 3.5, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 1.0, + timestep_offset: float = 0.6, + start_timestep: float = 0., + stop_timestep: float = 0.25, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, ): r""" Function invoked when calling the pipeline for generation. 
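(For orientation between these two hunks: with `eta`, `start_timestep` and `stop_timestep` now in the signature, each iteration of the denoising loop reduces to the Euler step sketched here — a minimal sketch, assuming `v_t`, `y_0`, `t_i` and the sigma schedule are the quantities computed inside `__call__`; `controlled_reverse_step` is a hypothetical helper name, cf. the `# SDE Eq: 17` comment in the loop.)

```py
# Minimal sketch of the controlled reverse-ODE step used in __call__.
# v_t: model velocity, y_0: packed source-image latents, eta: controller guidance.
def controlled_reverse_step(latents, v_t, y_0, t_i, eta, sigma_curr, sigma_next, guided):
    v_t_cond = (y_0 - latents) / (1 - t_i)  # conditional field pulling toward y_0
    v_hat_t = v_t + eta * (v_t_cond - v_t) if guided else v_t  # blend only while active
    return latents + v_hat_t * (sigma_curr - sigma_next)  # one explicit Euler step
```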
@@ -782,7 +761,10 @@ def __call__( sigmas, mu=mu, ) - timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + start_timestep = int(start_timestep * num_inference_steps) + stop_timestep = min(int(stop_timestep * num_inference_steps), num_inference_steps) + + timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, timestep_offset) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -873,19 +855,19 @@ def __call__( @torch.no_grad() def invert( - self, - image: PipelineImageInput, - source_prompt: str = "", - source_guidance_scale=0.0, - num_inversion_steps: int = 28, - gamma: float = 0.5, - strength: float = 0.6, - height: Optional[int] = None, - width: Optional[int] = None, - timesteps: List[int] = None, - dtype: Optional[torch.dtype] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale=0.0, + num_inversion_steps: int = 28, + timestep_offset: float = 0.6, + gamma: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + timesteps: List[int] = None, + dtype: Optional[torch.dtype] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792 @@ -952,7 +934,7 @@ def invert( sigmas, mu=mu, ) - timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength) + timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, timestep_offset=timestep_offset) # 3. 
prepare text embeddings ( From 61aea50b7b4fe01dccda0ac3ea9f00ae2dbc2b91 Mon Sep 17 00:00:00 2001 From: Linoy Date: Tue, 5 Nov 2024 13:52:44 +0000 Subject: [PATCH 22/45] style --- .../community/pipeline_flux_rf_inversion.py | 76 ++++++++++--------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index ec8852783429..880c0a218fd7 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -591,29 +591,29 @@ def interrupt(self): @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( - self, - prompt: Union[str, List[str]] = None, - prompt_2: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - eta: float = 1.0, - timestep_offset: float = 0.6, - start_timestep: float = 0., - stop_timestep: float = 0.25, - num_inference_steps: int = 28, - timesteps: List[int] = None, - guidance_scale: float = 3.5, - num_images_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 1.0, + timestep_offset: float = 0.6, + start_timestep: float = 0.0, + stop_timestep: float = 0.25, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, ): r""" Function invoked when calling the pipeline for generation. 
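(Between hunks, for reference: the forward loop in `invert` that this style pass also touches is a plain Euler discretization of Eq. 8, dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt. A minimal sketch under the same names as the pipeline code — `u_t_i` is the transformer's unconditional velocity, `y_1` the Gaussian noise target; `controlled_forward_step` is a hypothetical helper name.)

```py
# Minimal sketch of one Euler step of Eq. 8 as used in invert.
def controlled_forward_step(Y_t, u_t_i, y_1, t_i, gamma, sigma_curr, sigma_next):
    u_t_i_cond = (y_1 - Y_t) / (1 - t_i)  # conditional field driving Y_t toward y_1
    u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i)  # Eq. 8 interpolation
    return Y_t + u_hat_t_i * (sigma_curr - sigma_next)  # one Euler step
```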
@@ -853,19 +853,19 @@ def __call__( @torch.no_grad() def invert( - self, - image: PipelineImageInput, - source_prompt: str = "", - source_guidance_scale=0.0, - num_inversion_steps: int = 28, - timestep_offset: float = 0.6, - gamma: float = 0.5, - height: Optional[int] = None, - width: Optional[int] = None, - timesteps: List[int] = None, - dtype: Optional[torch.dtype] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale=0.0, + num_inversion_steps: int = 28, + timestep_offset: float = 0.6, + gamma: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + timesteps: List[int] = None, + dtype: Optional[torch.dtype] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792 @@ -932,7 +932,9 @@ def invert( sigmas, mu=mu, ) - timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, timestep_offset=timestep_offset) + timesteps, sigmas, num_inversion_steps = self.get_timesteps( + num_inversion_steps, timestep_offset=timestep_offset + ) # 3. prepare text embeddings ( From 61c810572b170d8573cd9349ebe4a61baecd2087 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 25 Nov 2024 10:27:22 +0200 Subject: [PATCH 23/45] forward ode loop change --- .../community/pipeline_flux_rf_inversion.py | 51 ++++++++++--------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 880c0a218fd7..525110dc9ea8 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -955,31 +955,34 @@ def invert( # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt Y_t = image_latents y_1 = torch.randn_like(Y_t) + N = len(sigmas) - for i in range(num_inversion_steps - 1): - t_i = torch.tensor(i / (num_inversion_steps), dtype=Y_t.dtype, device=device) - timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) - - # get the unconditional vector field - u_t_i = self.transformer( - hidden_states=Y_t, - timestep=timestep, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - # get the conditional vector field - u_t_i_cond = (y_1 - Y_t) / (1 - t_i) - - # controlled vector field - # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt - u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i) - Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1]) + # forward ODE loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(N - 1): + t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) + timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) + + # get the unconditional vector field + u_t_i = self.transformer( + hidden_states=Y_t, + timestep=timestep, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # get the conditional vector field + 
u_t_i_cond = (y_1 - Y_t) / (1 - t_i) + + # controlled vector field + # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt + u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i) + Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1]) self.inverted_latents = Y_t self.latent_image_ids = latent_image_ids From 093a91bddddeb432d938119f88cace75559fba3a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 25 Nov 2024 14:09:56 +0200 Subject: [PATCH 24/45] add inversion progress bar --- examples/community/pipeline_flux_rf_inversion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 525110dc9ea8..9a8b7ef08fa0 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -958,7 +958,7 @@ def invert( N = len(sigmas) # forward ODE loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=N) as progress_bar: for i in range(N - 1): t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) @@ -983,8 +983,9 @@ def invert( # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i) Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1]) + progress_bar.update() self.inverted_latents = Y_t self.latent_image_ids = latent_image_ids - return self.image_latents, Y_t, latent_image_ids + return self.image_latents, Y_t, latent_image_ids \ No newline at end of file From e1c905c3a34f01e453794e0597b37bb902d7cf3b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 25 Nov 2024 14:41:01 +0200 Subject: [PATCH 25/45] fix image_seq_len --- .../community/pipeline_flux_rf_inversion.py | 32 ++++--------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 9a8b7ef08fa0..0d6fc3557be8 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -534,26 +534,12 @@ def prepare_latents( width, dtype, device, - generator, - latents=None, + image_latents, ): height = int(height) // self.vae_scale_factor width = int(width) // self.vae_scale_factor - shape = (batch_size, num_channels_latents, height, width) - - if latents is not None: - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) @@ -745,7 +731,7 @@ def __call__( # 5. 
Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = latents.shape[1] + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -905,18 +891,14 @@ def invert( # 1. prepare image image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype) - _, latent_image_ids = self.prepare_latents( - batch_size, num_channels_latents, height, width, dtype, device, generator + image_latents, latent_image_ids = self.prepare_latents( + batch_size, num_channels_latents, height, width, dtype, device, image_latents ) - - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor - image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) self.image_latents = image_latents.clone() # 2. prepare timesteps sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, @@ -958,7 +940,7 @@ def invert( N = len(sigmas) # forward ODE loop - with self.progress_bar(total=N) as progress_bar: + with self.progress_bar(total=N-1) as progress_bar: for i in range(N - 1): t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) From 645885ff79af9c482a10fde22adafadb15f43971 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 25 Nov 2024 18:07:57 +0200 Subject: [PATCH 26/45] revert to strength but == 1 by default. 
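With `strength` back as the truncation knob and defaulting to 1.0, `get_timesteps` keeps only the last `strength * num_inference_steps` entries of the schedule, so the default reproduces the full schedule. A minimal sketch of the slicing, mirroring the helper in the diff below (`truncate_schedule` is a hypothetical standalone name; `scheduler` is the pipeline's scheduler after `set_timesteps` has been called):

```py
# Sketch of the schedule truncation performed by get_timesteps.
def truncate_schedule(scheduler, num_inference_steps, strength=1.0):
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    sigmas = scheduler.sigmas[t_start * scheduler.order :]
    # strength == 1.0 keeps the full schedule (t_start == 0)
    return timesteps, sigmas, num_inference_steps - t_start
```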
--- examples/community/pipeline_flux_rf_inversion.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 0d6fc3557be8..75e738e00947 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -546,9 +546,9 @@ def prepare_latents( return latents, latent_image_ids # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, timestep_offset): + def get_timesteps(self, num_inference_steps, strength=1.): # get the original timestep using init_timestep - init_timestep = min(num_inference_steps * timestep_offset, num_inference_steps) + init_timestep = min(num_inference_steps * strength, num_inference_steps) t_start = int(max(num_inference_steps - init_timestep, 0)) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] @@ -583,8 +583,8 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, eta: float = 1.0, - timestep_offset: float = 0.6, - start_timestep: float = 0.0, + strength: float = 1., + start_timestep: float = 0, stop_timestep: float = 0.25, num_inference_steps: int = 28, timesteps: List[int] = None, @@ -749,8 +749,7 @@ def __call__( ) start_timestep = int(start_timestep * num_inference_steps) stop_timestep = min(int(stop_timestep * num_inference_steps), num_inference_steps) - - timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, timestep_offset) + timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -844,7 +843,7 @@ def invert( source_prompt: str = "", source_guidance_scale=0.0, num_inversion_steps: int = 28, - timestep_offset: float = 0.6, + strength: float = 1., gamma: float = 0.5, height: Optional[int] = None, width: Optional[int] = None, @@ -915,7 +914,7 @@ def invert( mu=mu, ) timesteps, sigmas, num_inversion_steps = self.get_timesteps( - num_inversion_steps, timestep_offset=timestep_offset + num_inversion_steps, strength ) # 3. 
prepare text embeddings From d887572c8151b7f13ac2363796b2f6007fc99ade Mon Sep 17 00:00:00 2001 From: Linoy Date: Mon, 25 Nov 2024 16:09:51 +0000 Subject: [PATCH 27/45] style --- examples/community/pipeline_flux_rf_inversion.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 75e738e00947..77aa6d436ea0 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -34,7 +34,6 @@ scale_lora_layers, unscale_lora_layers, ) -from diffusers.utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -546,7 +545,7 @@ def prepare_latents( return latents, latent_image_ids # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength=1.): + def get_timesteps(self, num_inference_steps, strength=1.0): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) @@ -583,7 +582,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, eta: float = 1.0, - strength: float = 1., + strength: float = 1.0, start_timestep: float = 0, stop_timestep: float = 0.25, num_inference_steps: int = 28, @@ -843,7 +842,7 @@ def invert( source_prompt: str = "", source_guidance_scale=0.0, num_inversion_steps: int = 28, - strength: float = 1., + strength: float = 1.0, gamma: float = 0.5, height: Optional[int] = None, width: Optional[int] = None, @@ -913,9 +912,7 @@ def invert( sigmas, mu=mu, ) - timesteps, sigmas, num_inversion_steps = self.get_timesteps( - num_inversion_steps, strength - ) + timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength) # 3. prepare text embeddings ( @@ -939,7 +936,7 @@ def invert( N = len(sigmas) # forward ODE loop - with self.progress_bar(total=N-1) as progress_bar: + with self.progress_bar(total=N - 1) as progress_bar: for i in range(N - 1): t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device) timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size) @@ -969,4 +966,4 @@ def invert( self.inverted_latents = Y_t self.latent_image_ids = latent_image_ids - return self.image_latents, Y_t, latent_image_ids \ No newline at end of file + return self.image_latents, Y_t, latent_image_ids From 8c4d5c1158fa146f6c7a6bcac879dede079701d5 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 27 Nov 2024 19:00:09 +0200 Subject: [PATCH 28/45] add "copied from..." 
comments --- examples/community/pipeline_flux_rf_inversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 77aa6d436ea0..8fb48a304ef0 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -71,7 +71,7 @@ ``` """ - +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, @@ -211,6 +211,7 @@ def __init__( ) self.default_sample_size = 128 + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, From 3f37448738e7eeb7437e3e88b33f9d3331132e2b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 27 Nov 2024 20:54:46 +0200 Subject: [PATCH 29/45] credit authors --- examples/community/pipeline_flux_rf_inversion.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 8fb48a304ef0..b51287a0f944 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -1,4 +1,6 @@ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# modeled after RF Inversion: https://rf-inversion.github.io/, authored by Litu Rout, Yujia Chen, Nataniel Ruiz, +# Constantine Caramanis, Sanjay Shakkottai and Wen-Sheng Chu. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -52,7 +54,10 @@ >>> import torch >>> from diffusers import FluxPipeline - >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = DiffusionPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-dev", + ... torch_dtype=torch.bfloat16, + ... custom_pipeline="pipeline_flux_rf_inversion") >>> pipe.to("cuda") >>> def download_image(url): @@ -63,10 +68,14 @@ >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg" >>> image = download_image(img_url) - >>> _,__,___ = pipe.invert(image=image, num_inversion_steps=28) + >>> _,__,___ = pipe.invert(image=image, num_inversion_steps=28, gamma=0.5) >>> edited_image = pipe( - ... prompt="a portrait of a tiger", + ... prompt="a tomato", + ... start_timestep=0, + ... stop_timestep=.38, + ... num_inference_steps=28, + ... eta=0.9, ... ).images[0] ``` """ From e06b434a17ba2dde13db1386a9466e0af45712ed Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 6 Dec 2024 08:18:57 +0000 Subject: [PATCH 30/45] make style --- examples/community/pipeline_flux_rf_inversion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index b51287a0f944..2e526df92ec7 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -72,14 +72,15 @@ >>> edited_image = pipe( ... prompt="a tomato", - ... start_timestep=0, + ... start_timestep=0, ... stop_timestep=.38, ... num_inference_steps=28, - ... eta=0.9, + ... eta=0.9, ... 
).images[0] ``` """ + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, From e318a6701c4db7a39cf60a30bf04224e80015997 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 6 Dec 2024 15:17:24 +0200 Subject: [PATCH 31/45] return inversion outputs without self-assigning --- examples/community/pipeline_flux_rf_inversion.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 2e526df92ec7..fee9ae4bb80f 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -588,6 +588,9 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, + latents: Optional[torch.FloatTensor] = None, + image_latents: Optional[torch.FloatTensor] = None, + latent_image_ids: Optional[torch.FloatTensor] = None, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, @@ -601,7 +604,6 @@ def __call__( guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -735,9 +737,6 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents = self.inverted_latents - latent_image_ids = self.latent_image_ids - image_latents = self.image_latents # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) @@ -859,7 +858,6 @@ def invert( width: Optional[int] = None, timesteps: List[int] = None, dtype: Optional[torch.dtype] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" @@ -903,7 +901,6 @@ def invert( image_latents, latent_image_ids = self.prepare_latents( batch_size, num_channels_latents, height, width, dtype, device, image_latents ) - self.image_latents = image_latents.clone() # 2. 
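The forward (inversion) loop that `invert` runs, continued in the hunks below, follows the controlled forward ODE of RF-Inversion (Algorithm 1 in https://arxiv.org/pdf/2410.10792): the model velocity is blended with a velocity pointing from the current latents toward a Gaussian sample, weighted by `gamma`. One Euler step of that update, sketched with hypothetical names (`u_t` is the model velocity; the exact conditional term is an assumption here, not quoted from the pipeline):

```py
import torch

# Hedged sketch of one controlled forward-ODE step (Algorithm 1).
def forward_ode_step(y_t, y_1, u_t, t_i, gamma, sigma_now, sigma_next):
    u_cond = (y_1 - y_t) / (1.0 - t_i)    # assumed velocity toward the noise sample y_1
    u_hat = u_t + gamma * (u_cond - u_t)  # gamma-controlled blend
    # sigmas decrease along the schedule, so (sigma_now - sigma_next) > 0,
    # mirroring the `Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1])` update
    return y_t + u_hat * (sigma_now - sigma_next)

y_t = torch.randn(1, 1024, 64)  # packed latents; shape is illustrative
step = forward_ode_step(
    y_t, torch.randn_like(y_t), torch.randn_like(y_t),
    t_i=0.1, gamma=0.5, sigma_now=1.0, sigma_next=0.96,
)
```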
prepare timesteps sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps) @@ -974,7 +971,5 @@ def invert( Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1]) progress_bar.update() - self.inverted_latents = Y_t - self.latent_image_ids = latent_image_ids - - return self.image_latents, Y_t, latent_image_ids + # return the inverted latents (start point for the denoising loop), encoded image & latent image ids + return Y_t, image_latents, latent_image_ids From d523e2bba7f887446918cc43ec1a12ef4f5af262 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 6 Dec 2024 15:33:51 +0200 Subject: [PATCH 32/45] adjust denoising loop to generate regular images if inverted latents are not provided --- .../community/pipeline_flux_rf_inversion.py | 103 +++++++++++++++--- 1 file changed, 86 insertions(+), 17 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index fee9ae4bb80f..02b46d36d7ae 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -419,6 +419,9 @@ def check_inputs( self, prompt, prompt_2, + inverted_latents, + image_latents, + latent_image_ids, height, width, start_timestep, @@ -467,6 +470,10 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + if inverted_latents is not None and (image_latents is None or latent_image_ids is None): + raise ValueError( + "If `inverted_latents` are provided, `image_latents` and `latent_image_ids` also have to be passed. " + ) # check start_timestep and stop_timestep if start_timestep < 0 or start_timestep > stop_timestep: raise ValueError(f"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}") @@ -536,7 +543,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - def prepare_latents( + def prepare_latents_inversion( self, batch_size, num_channels_latents, @@ -555,6 +562,41 @@ def prepare_latents( return latents, latent_image_ids + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
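The hunk resumes below with noise sampling and Flux's latent packing. For background, `_pack_latents` is understood to rearrange a `(B, C, H, W)` latent into a sequence of 2x2 patches, which is why `prepare_latents` rounds `height` and `width` to multiples of two and builds image ids on a half-resolution grid. A sketch of that rearrangement, stated as an assumption about the helper rather than its verbatim source:

```py
import torch

# Assumed Flux-style 2x2 packing: (B, C, H, W) -> (B, H/2 * W/2, C * 4).
def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape  # h and w are assumed even
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

print(pack_latents_sketch(torch.randn(1, 16, 64, 64)).shape)
# torch.Size([1, 1024, 64])
```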
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength=1.0): # get the original timestep using init_timestep @@ -588,11 +630,11 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - latents: Optional[torch.FloatTensor] = None, - image_latents: Optional[torch.FloatTensor] = None, - latent_image_ids: Optional[torch.FloatTensor] = None, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, + inverted_latents: Optional[torch.FloatTensor] = None, + image_latents: Optional[torch.FloatTensor] = None, + latent_image_ids: Optional[torch.FloatTensor] = None, height: Optional[int] = None, width: Optional[int] = None, eta: float = 1.0, @@ -604,6 +646,7 @@ def __call__( guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -693,6 +736,9 @@ def __call__( self.check_inputs( prompt, prompt_2, + inverted_latents, + image_latents, + latent_image_ids, height, width, start_timestep, @@ -706,6 +752,7 @@ def __call__( self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False + do_rf_inversion = inverted_latents is not None # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -737,6 +784,19 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 + if do_rf_inversion: + latents = inverted_latents + else: + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) @@ -769,9 +829,11 @@ def __call__( else: guidance = None + if do_rf_inversion: + y_0 = image_latents.clone() # 6. 
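When inverted latents are passed, the loop below implements the controlled reverse ODE (Algorithm 2 of https://arxiv.org/pdf/2410.10792): the flow velocity `v_t = -noise_pred` is steered toward the source image latents `y_0` with strength `eta`. One step of that update, condensed from the hunk into a standalone sketch:

```py
# Sketch of one controlled reverse-ODE step (Algorithm 2); names mirror the
# hunk below, but this is an illustration, not the pipeline's loop.
def reverse_ode_step(latents, y_0, noise_pred, t_i, eta_t, sigma_now, sigma_next):
    v_t = -noise_pred                         # unconditional flow velocity
    v_t_cond = (y_0 - latents) / (1 - t_i)    # pull toward the source latents
    v_hat_t = v_t + eta_t * (v_t_cond - v_t)  # controlled vector field (SDE Eq. 17)
    return latents + v_hat_t * (sigma_now - sigma_next)
```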
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: - y_0 = image_latents.clone() + for i, t in enumerate(timesteps): t_i = 1 - t / 1000 dt = torch.tensor(1 / (len(timesteps) - 1), device=device) @@ -782,7 +844,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - v_t = -self.transformer( + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, @@ -794,18 +856,25 @@ def __call__( return_dict=False, )[0] - v_t_cond = (y_0 - latents) / (1 - t_i) - eta_t = eta if start_timestep <= i < stop_timestep else 0.0 - if start_timestep <= i < stop_timestep: - # controlled vector field - v_hat_t = v_t + eta * (v_t_cond - v_t) + if do_rf_inversion: + v_t = -noise_pred - else: - v_hat_t = v_t - # SDE Eq: 17 + v_t_cond = (y_0 - latents) / (1 - t_i) + eta_t = eta if start_timestep <= i < stop_timestep else 0.0 + if start_timestep <= i < stop_timestep: + # controlled vector field + v_hat_t = v_t + eta * (v_t_cond - v_t) - latents_dtype = latents.dtype - latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) + else: + v_hat_t = v_t + # SDE Eq: 17 + + latents_dtype = latents.dtype + latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) + else: + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): @@ -898,7 +967,7 @@ def invert( # 1. prepare image image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype) - image_latents, latent_image_ids = self.prepare_latents( + image_latents, latent_image_ids = self.prepare_latents_inversion( batch_size, num_channels_latents, height, width, dtype, device, image_latents ) From 6d98da3bb3cb5a40a55e1ecc572b96ce9df7d2bc Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 6 Dec 2024 15:36:30 +0200 Subject: [PATCH 33/45] adjust denoising loop to generate regular images if inverted latents are not provided --- examples/community/pipeline_flux_rf_inversion.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 02b46d36d7ae..c7498f91cf59 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -799,7 +799,7 @@ def __call__( ) # 5. 
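In the follow-up patch below, the fractional `start_timestep` and `stop_timestep` arguments are converted to step indices only when inverted latents were provided, so plain text-to-image calls are unaffected. The resulting gating confines the controller to a window of steps; with the example values from the docstring:

```py
# Gating of the controller window, using the docstring's example values
# eta=0.9, start_timestep=0, stop_timestep=0.38.
num_inference_steps = 28
start_step = int(0.0 * num_inference_steps)                            # 0
stop_step = min(int(0.38 * num_inference_steps), num_inference_steps)  # 10
etas = [0.9 if start_step <= i < stop_step else 0.0 for i in range(num_inference_steps)]
print(etas[:12])  # ten 0.9 entries, then 0.0: guidance only in the first 10 steps
```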
Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, @@ -816,9 +816,10 @@ def __call__( sigmas, mu=mu, ) - start_timestep = int(start_timestep * num_inference_steps) - stop_timestep = min(int(stop_timestep * num_inference_steps), num_inference_steps) - timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + if do_rf_inversion: + start_timestep = int(start_timestep * num_inference_steps) + stop_timestep = min(int(stop_timestep * num_inference_steps), num_inference_steps) + timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) @@ -833,10 +834,10 @@ def __call__( y_0 = image_latents.clone() # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - t_i = 1 - t / 1000 - dt = torch.tensor(1 / (len(timesteps) - 1), device=device) + if do_rf_inversion: + t_i = 1 - t / 1000 + dt = torch.tensor(1 / (len(timesteps) - 1), device=device) if self.interrupt: continue From 37b70a716562a0eb0a60fe130a2721777593d739 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 6 Dec 2024 15:52:11 +0200 Subject: [PATCH 34/45] fix import --- examples/community/pipeline_flux_rf_inversion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index c7498f91cf59..9b46dfb24aba 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -36,6 +36,7 @@ scale_lora_layers, unscale_lora_layers, ) +from diffusers.utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -642,6 +643,7 @@ def __call__( start_timestep: float = 0, stop_timestep: float = 0.25, num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, timesteps: List[int] = None, guidance_scale: float = 3.5, num_images_per_prompt: Optional[int] = 1, From c4ce7401857640d59330b78188e28ab5297b54e7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 6 Dec 2024 17:30:33 +0200 Subject: [PATCH 35/45] comment --- examples/community/pipeline_flux_rf_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 9b46dfb24aba..0c8b40ebf276 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -870,8 +870,8 @@ def __call__( else: v_hat_t = v_t - # SDE Eq: 17 + # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 latents_dtype = latents.dtype latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) else: From a61c93fed3846d7f85fe7afb86527707d610333b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 6 Dec 2024 17:31:25 +0200 Subject: [PATCH 36/45] remove redundant line --- examples/community/pipeline_flux_rf_inversion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 0c8b40ebf276..c153913031bc 100644 --- 
a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -859,6 +859,7 @@ def __call__( return_dict=False, )[0] + latents_dtype = latents.dtype if do_rf_inversion: v_t = -noise_pred @@ -872,11 +873,9 @@ def __call__( v_hat_t = v_t # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 - latents_dtype = latents.dtype latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) else: # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: From 9116a056ec12665f849359f01892dc04e142c91f Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 6 Dec 2024 18:01:02 +0200 Subject: [PATCH 37/45] modify comment on ti --- examples/community/pipeline_flux_rf_inversion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index c153913031bc..c1fed43150fe 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -834,10 +834,11 @@ def __call__( if do_rf_inversion: y_0 = image_latents.clone() - # 6. Denoising loop + # 6. Denoising loop / Controlled Reverse ODE, Algorithm 2 from: https://arxiv.org/pdf/2410.10792 with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if do_rf_inversion: + # ti (current timestep) as annotated in algorithm 2 - i/num_inference_steps. t_i = 1 - t / 1000 dt = torch.tensor(1 / (len(timesteps) - 1), device=device) @@ -862,7 +863,6 @@ def __call__( latents_dtype = latents.dtype if do_rf_inversion: v_t = -noise_pred - v_t_cond = (y_0 - latents) / (1 - t_i) eta_t = eta if start_timestep <= i < stop_timestep else 0.0 if start_timestep <= i < stop_timestep: From e87d45a865115ab4d8d305c44518c6d5d3a43a73 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:42:10 +0200 Subject: [PATCH 38/45] Update examples/community/pipeline_flux_rf_inversion.py Co-authored-by: hlky --- examples/community/pipeline_flux_rf_inversion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index c1fed43150fe..4d27d79287e7 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -69,10 +69,13 @@ >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg" >>> image = download_image(img_url) - >>> _,__,___ = pipe.invert(image=image, num_inversion_steps=28, gamma=0.5) + >>> inverted_latents, image_latents, latent_image_ids = pipe.invert(image=image, num_inversion_steps=28, gamma=0.5) >>> edited_image = pipe( ... prompt="a tomato", + ... inverted_latents=inverted_latents, + ... image_latents=image_latents, + ... latent_image_ids=latent_image_ids, ... start_timestep=0, ... stop_timestep=.38, ... 
num_inference_steps=28, From f8aeb1681cfde0019cc146b3687e3a532daf2f0e Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:42:45 +0200 Subject: [PATCH 39/45] Update examples/community/pipeline_flux_rf_inversion.py Co-authored-by: hlky --- examples/community/pipeline_flux_rf_inversion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 4d27d79287e7..003722c3b90e 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -80,6 +80,9 @@ ... stop_timestep=.38, ... num_inference_steps=28, ... eta=0.9, + ... stop_timestep=.38, + ... num_inference_steps=28, + ... eta=0.9, ... ).images[0] ``` """ From f290b242b8b9a950eccca3d75c2cc4bcff91befb Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:43:00 +0200 Subject: [PATCH 40/45] Update examples/community/pipeline_flux_rf_inversion.py Co-authored-by: hlky --- examples/community/pipeline_flux_rf_inversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 003722c3b90e..ab9d6de00559 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -322,6 +322,7 @@ def _get_clip_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], From e42dd6bcf322147f3b5cd3a00ae5ec46e88153e3 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:43:29 +0200 Subject: [PATCH 41/45] Update examples/community/pipeline_flux_rf_inversion.py Co-authored-by: hlky --- examples/community/pipeline_flux_rf_inversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index ab9d6de00559..90a85e1f6ba6 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -278,6 +278,7 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], From 0036ae2a3c13f86747eab1fb91e452110f57e121 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:44:27 +0200 Subject: [PATCH 42/45] Update examples/community/pipeline_flux_rf_inversion.py Co-authored-by: hlky --- examples/community/pipeline_flux_rf_inversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 90a85e1f6ba6..38bb08e4c452 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -279,6 +279,7 @@ def _get_t5_prompt_embeds( return prompt_embeds # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], From 074cf5adfe4ab331a4b882c2a2824defc23f96df Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 9 
Dec 2024 18:45:39 +0200 Subject: [PATCH 43/45] Update examples/community/pipeline_flux_rf_inversion.py Co-authored-by: hlky --- examples/community/pipeline_flux_rf_inversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 38bb08e4c452..ae1a44e3fdf5 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -572,6 +572,7 @@ def prepare_latents_inversion( return latents, latent_image_ids + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, batch_size, From 5f1407dc8e49b3d1a450ebe1e76ea60ea9b4d023 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 9 Dec 2024 18:46:18 +0200 Subject: [PATCH 44/45] Update examples/community/pipeline_flux_rf_inversion.py Co-authored-by: hlky --- examples/community/pipeline_flux_rf_inversion.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index ae1a44e3fdf5..0a208d4b4b28 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -678,6 +678,12 @@ def __call__( prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is will be used instead + inverted_latents (`torch.Tensor`, *optional*): + The inverted latents from `pipe.invert`. + image_latents (`torch.Tensor`, *optional*): + The image latents from `pipe.invert`. + latent_image_ids (`torch.Tensor`, *optional*): + The latent image ids from `pipe.invert`. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): From 1cb1ce62871406d4a7ca36d5ce1eca182129257b Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 9 Dec 2024 18:52:16 +0200 Subject: [PATCH 45/45] fix syntax error --- examples/community/pipeline_flux_rf_inversion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index 0a208d4b4b28..7f5f1b02695e 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -279,7 +279,6 @@ def _get_t5_prompt_embeds( return prompt_embeds # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds - def _get_clip_prompt_embeds( def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], @@ -723,7 +722,7 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + Whether to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in