From 825fdfd1f9b53187a7da32dde341f25911f7de2d Mon Sep 17 00:00:00 2001
From: YibLiu <68105073+YibinLiu666@users.noreply.github.com>
Date: Thu, 18 Jan 2024 15:16:55 +0800
Subject: [PATCH] Add StableDiffusionXL Reference Control support (#369)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

https://github.com/PaddlePaddle/PaddleMIX/issues/252
---
ppdiffusers/examples/community/README.md | 33 +
.../pipline_stable_diffusion_xl_reference.py | 821 ++++++++++++++++++
2 files changed, 854 insertions(+)
create mode 100644 ppdiffusers/examples/community/pipline_stable_diffusion_xl_reference.py
diff --git a/ppdiffusers/examples/community/README.md b/ppdiffusers/examples/community/README.md
index f3d959b0a..e2f1f3ca1 100644
--- a/ppdiffusers/examples/community/README.md
+++ b/ppdiffusers/examples/community/README.md
@@ -11,6 +11,7 @@
|AUTOMATIC1111 WebUI Stable Diffusion| 与AUTOMATIC1111的WebUI基本一致的Pipeline |[AUTOMATIC1111 WebUI Stable Diffusion](#automatic1111-webui-stable-diffusion)||
|Stable Diffusion with High Resolution Fixing| 使用高分辨率修复功能进行文图生成|[Stable Diffusion with High Resolution Fixing](#stable-diffusion-with-high-resolution-fixing)||
|ControlNet Reference Only| 基于参考图片生成与图片相似的图片|[ControlNet Reference Only](#controlnet-reference-only)||
+|Stable Diffusion XL Reference| Generate images similar to a reference image with Stable Diffusion XL |[Stable Diffusion XL Reference](#stable-diffusion-xl-reference)||
|Stable Diffusion Mixture Tiling| 基于Mixture机制的多文本大图生成Stable Diffusion Pipeline|[Stable Diffusion Mixture Tiling](#stable-diffusion-mixture-tiling)||
|CLIP Guided Images Mixing Stable Diffusion Pipeline| 一个用于图片融合的Stable Diffusion Pipeline|[CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion)||
|EDICT Image Editing Pipeline| 一个用于文本引导的图像编辑的 Stable Diffusion Pipeline|[EDICT Image Editing Pipeline](#edict_pipeline)||
@@ -524,6 +525,38 @@ for control_name in ["none", "reference_only", "reference_adain", "reference_ada
[reference_adain]: https://github.com/PaddlePaddle/PaddleNLP/assets/50394665/266968c7-5065-4589-9bd8-47515d50c6de
[reference_adain+attn]: https://github.com/PaddlePaddle/PaddleNLP/assets/50394665/73d53a4f-e601-4969-9cb8-e3fdf719ae0c
+### Stable Diffusion XL Reference
+[Stable Diffusion XL Reference](https://github.com/Mikubill/sd-webui-controlnet#reference-only-control) is a method, built on Stable Diffusion XL, that uses an image directly as a reference to guide generation, without requiring any control model. Usage is shown below:
+
+```python
+import paddle
+from PIL import Image
+from ppdiffusers.utils import load_image
+from pipline_stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
+from ppdiffusers.schedulers import UniPCMultistepScheduler
+
+input_image = load_image("https://github.com/PaddlePaddle/PaddleMIX/assets/68105073/9c8e5c53-dc9a-46bb-9504-3d75a7c22ed2")
+
+pipe = StableDiffusionXLReferencePipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ use_safetensors=True,
+ variant="fp16")
+
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+
+result_img = pipe(ref_image=input_image,
+ prompt="a dog running on grassland, best quality",
+ num_inference_steps=20,
+ reference_attn=True,
+ reference_adain=False).images[0]
+
+result_img.save("output.png")
+```
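+
+The `reference_attn` and `reference_adain` flags control whether reference features are injected through the UNet's self-attention layers and through its group-norm (AdaIN) statistics, respectively; `style_fidelity` (default 0.5) controls how strongly the unconditional branch is kept free of reference features, so higher values tend to follow the reference style more closely.
+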
+Reference image:
+
+
+The generated image is shown below:
+
### Stable Diffusion Mixture Tiling
`StableDiffusionTilingPipeline`是一个基于Mixture机制的多文本大图生成Stable Diffusion Pipeline。使用方式如下所示:
diff --git a/ppdiffusers/examples/community/pipline_stable_diffusion_xl_reference.py b/ppdiffusers/examples/community/pipline_stable_diffusion_xl_reference.py
new file mode 100644
index 000000000..260b59541
--- /dev/null
+++ b/ppdiffusers/examples/community/pipline_stable_diffusion_xl_reference.py
@@ -0,0 +1,821 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import paddle
+import PIL.Image
+
+from ppdiffusers.models.attention import BasicTransformerBlock
+from ppdiffusers.models.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ DownBlock2D,
+ UpBlock2D,
+)
+from ppdiffusers.pipelines.stable_diffusion_xl import (
+ StableDiffusionXLPipeline,
+ StableDiffusionXLPipelineOutput,
+)
+from ppdiffusers.utils import PIL_INTERPOLATION, logging
+from ppdiffusers.utils.paddle_utils import randn_tensor
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import paddle
+ >>> from PIL import Image
+ >>> from ppdiffusers.utils import load_image
+ >>> from pipline_stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
+ >>> from ppdiffusers.schedulers import UniPCMultistepScheduler
+
+ >>> input_image = load_image("https://raw.githubusercontent.com/Mikubill/sd-webui-controlnet/main/samples/dog_rel.png")
+
+    >>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
+    ...     "stabilityai/stable-diffusion-xl-base-1.0",
+    ...     use_safetensors=True,
+    ...     variant="fp16")
+
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+    >>> result_img = pipe(ref_image=input_image,
+    ...     prompt="a dog running on grassland, best quality",
+    ...     num_inference_steps=20,
+    ...     reference_attn=True,
+    ...     reference_adain=False).images[0]
+
+ >>> result_img.save("output.png")
+ ```
+"""
+
+
+def paddle_dfs(model: paddle.nn.Layer):
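+    # Depth-first traversal: return `model` followed by all of its sublayers.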
+ result = [model]
+ for child in model.children():
+ result += paddle_dfs(child)
+ return result
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(axis=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(axis=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
+ def _default_height_width(self, height, width, image):
+ # NOTE: It is possible that a list of images have different
+ # dimensions for each image, so just checking the first image
+ # is not _exactly_ correct, but it is simple.
+ while isinstance(image, list):
+ image = image[0]
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, paddle.Tensor):
+ height = image.shape[2]
+
+ height = (height // 8) * 8 # round down to nearest multiple of 8
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, paddle.Tensor):
+ width = image.shape[3]
+
+ width = (width // 8) * 8
+
+ return height, width
+
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
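+        # PIL inputs are converted to RGB, resized, normalized to [-1, 1] and
+        # stacked NCHW; the batch is then repeated per prompt and duplicated
+        # for classifier-free guidance unless guess_mode is set.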
+ if not isinstance(image, paddle.Tensor):
+ if isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ images = []
+
+ for image_ in image:
+ image_ = image_.convert("RGB")
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
+ image_ = np.array(image_)
+ image_ = image_[None, :]
+ images.append(image_)
+
+ image = images
+
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = (image - 0.5) / 0.5
+ image = image.transpose(0, 3, 1, 2)
+ image = paddle.to_tensor(image)
+
+ elif isinstance(image[0], paddle.Tensor):
+ image = paddle.stack(image, axis=0)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, axis=0)
+
+ image = image._to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = paddle.concat([image] * 2)
+
+ return image
+
+ def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
+ refimage = refimage._to(device=device)
+ if self.vae.dtype == paddle.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ refimage = refimage.cast(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ if refimage.dtype != self.vae.dtype:
+ refimage = refimage.cast(self.vae.dtype)
+        # encode the reference image into latent space so we can concatenate it to the latents
+ if isinstance(generator, list):
+ ref_image_latents = [
+ self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(batch_size)
+ ]
+ ref_image_latents = paddle.concat(ref_image_latents, axis=0)
+ else:
+ ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
+ ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
+
+        # duplicate ref_image_latents for each generation per prompt
+ if ref_image_latents.shape[0] < batch_size:
+ if not batch_size % ref_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ ref_image_latents = ref_image_latents.tile((batch_size // ref_image_latents.shape[0], 1, 1, 1))
+
+ ref_image_latents = (
+ paddle.concat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
+ )
+
+        # align device to prevent device errors when concatenating it with the latent model input
+ ref_image_latents = ref_image_latents._to(device=device, dtype=dtype)
+ return ref_image_latents
+
+ @paddle.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ ref_image: Union[paddle.Tensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
+ latents: Optional[paddle.Tensor] = None,
+ prompt_embeds: Optional[paddle.Tensor] = None,
+ negative_prompt_embeds: Optional[paddle.Tensor] = None,
+ pooled_prompt_embeds: Optional[paddle.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[paddle.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ attention_auto_machine_weight: float = 1.0,
+ gn_auto_machine_weight: float = 1.0,
+ style_fidelity: float = 0.5,
+ reference_attn: bool = True,
+ reference_adain: bool = True,
+ ):
+ assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
+
+ # 0. Default height and width to unet
+ # height, width = self._default_height_width(height, width, ref_image)
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = paddle.get_device()
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+ # 4. Preprocess reference image
+ ref_image = self.prepare_image(
+ image=ref_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=prompt_embeds.dtype,
+ )
+
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ generator,
+ latents,
+ )
+ # 7. Prepare reference latent variables
+ ref_image_latents = self.prepare_ref_latents(
+ ref_image,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 9. Modify self-attention and group norm
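+        # The hacked forwards below implement a two-pass scheme per timestep:
+        # in "write" mode, a UNet pass on the noised reference latents caches
+        # self-attention hidden states (and, for AdaIN, group-norm statistics);
+        # in "read" mode, the actual denoising pass consumes those caches.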
+ MODE = "write"
+ uc_mask = (
+ paddle.to_tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
+ .astype(ref_image_latents.dtype)
+ .astype("bool")
+ )
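+        # Boolean mask selecting the unconditional half of the CFG batch, used to
+        # keep that branch free of reference features when style_fidelity > 0.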
+
+ def hacked_basic_transformer_inner_forward(
+ self,
+ hidden_states: paddle.Tensor,
+ attention_mask: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ timestep: Optional[paddle.Tensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[paddle.Tensor] = None,
+ ):
+ if self.use_ada_layer_norm:
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.use_ada_layer_norm_zero:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ # 1. Self-Attention
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if self.only_cross_attention:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ if MODE == "write":
+ self.bank.append(norm_hidden_states.detach().clone())
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if MODE == "read":
+ if attention_auto_machine_weight > self.attn_weight:
+ attn_output_uc = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=paddle.concat([norm_hidden_states] + self.bank, axis=1),
+ # attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output_c = attn_output_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ attn_output_c[uc_mask] = self.attn1(
+ norm_hidden_states[uc_mask],
+ encoder_hidden_states=norm_hidden_states[uc_mask],
+ **cross_attention_kwargs,
+ )
+ attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc
+ self.bank.clear()
+ else:
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.use_ada_layer_norm_zero:
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.attn2 is not None:
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+
+ # 2. Cross-Attention
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.use_ada_layer_norm_zero:
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+ def hacked_mid_forward(self, *args, **kwargs):
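+            # AdaIN-style alignment for the mid block: "write" banks the
+            # reference's per-channel mean/var; "read" renormalizes the current
+            # features to the banked statistics.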
+ eps = 1e-6
+ x = self.original_forward(*args, **kwargs)
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var = paddle.var(x, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(x, axis=(2, 3), keepdim=True)
+ self.mean_bank.append(mean)
+ self.var_bank.append(var)
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var = paddle.var(x, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(x, axis=(2, 3), keepdim=True)
+ std = paddle.maximum(var, paddle.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))
+ var_acc = sum(self.var_bank) / float(len(self.var_bank))
+ std_acc = paddle.maximum(var_acc, paddle.zeros_like(var_acc) + eps) ** 0.5
+ x_uc = (((x - mean) / std) * std_acc) + mean_acc
+ x_c = x_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ x_c[uc_mask] = x[uc_mask]
+ x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
+ self.mean_bank = []
+ self.var_bank = []
+ return x
+
+ def hack_CrossAttnDownBlock2D_forward(
+ self,
+ hidden_states: paddle.Tensor,
+ temb: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ ):
+ eps = 1e-6
+
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ std = paddle.maximum(var, paddle.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = paddle.maximum(var_acc, paddle.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
+ eps = 1e-6
+
+ output_states = ()
+
+ for i, resnet in enumerate(self.resnets):
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ std = paddle.maximum(var, paddle.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = paddle.maximum(var_acc, paddle.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ output_states = output_states + (hidden_states,)
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+ def hacked_CrossAttnUpBlock2D_forward(
+ self,
+ hidden_states: paddle.Tensor,
+ res_hidden_states_tuple: Tuple[paddle.Tensor, ...],
+ temb: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ ):
+ eps = 1e-6
+ # TODO(Patrick, William) - attention mask is not used
+ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ std = paddle.maximum(var, paddle.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = paddle.maximum(var_acc, paddle.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ eps = 1e-6
+ for i, resnet in enumerate(self.resnets):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ if MODE == "write":
+ if gn_auto_machine_weight >= self.gn_weight:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ self.mean_bank.append([mean])
+ self.var_bank.append([var])
+ if MODE == "read":
+ if len(self.mean_bank) > 0 and len(self.var_bank) > 0:
+ var = paddle.var(hidden_states, axis=(2, 3), unbiased=True, keepdim=True)
+ mean = paddle.mean(hidden_states, axis=(2, 3), keepdim=True)
+ std = paddle.maximum(var, paddle.zeros_like(var) + eps) ** 0.5
+ mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))
+ var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))
+ std_acc = paddle.maximum(var_acc, paddle.zeros_like(var_acc) + eps) ** 0.5
+ hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc
+ hidden_states_c = hidden_states_uc.clone()
+ if do_classifier_free_guidance and style_fidelity > 0:
+ hidden_states_c[uc_mask] = hidden_states[uc_mask]
+ hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc
+
+ if MODE == "read":
+ self.mean_bank = []
+ self.var_bank = []
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+ if reference_attn:
+ attn_modules = [module for module in paddle_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1._normalized_shape[0])
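+            # Each block gets an attn_weight in [0, 1); in "read" mode a block
+            # injects reference features only when attention_auto_machine_weight
+            # exceeds its attn_weight.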
+
+ for i, module in enumerate(attn_modules):
+ module._original_inner_forward = module.forward
+ module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
+ module.bank = []
+ module.attn_weight = float(i) / float(len(attn_modules))
+
+ if reference_adain:
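+            # Patch the mid, down and up blocks so their forwards bank / consume
+            # group-norm statistics; a block records statistics only when
+            # gn_auto_machine_weight >= its gn_weight (assigned below, then doubled).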
+ gn_modules = [self.unet.mid_block]
+ self.unet.mid_block.gn_weight = 0
+
+ down_blocks = self.unet.down_blocks
+ for w, module in enumerate(down_blocks):
+ module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
+ gn_modules.append(module)
+
+ up_blocks = self.unet.up_blocks
+ for w, module in enumerate(up_blocks):
+ module.gn_weight = float(w) / float(len(up_blocks))
+ gn_modules.append(module)
+
+ for i, module in enumerate(gn_modules):
+ if getattr(module, "original_forward", None) is None:
+ module.original_forward = module.forward
+ if i == 0:
+ # mid_block
+ module.forward = hacked_mid_forward.__get__(module, paddle.nn.Layer)
+ elif isinstance(module, CrossAttnDownBlock2D):
+ module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)
+ elif isinstance(module, DownBlock2D):
+ module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)
+ elif isinstance(module, CrossAttnUpBlock2D):
+ module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
+ elif isinstance(module, UpBlock2D):
+ module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)
+ module.mean_bank = []
+ module.var_bank = []
+ module.gn_weight *= 2
+
+ # 10. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+
+ if do_classifier_free_guidance:
+ prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds], axis=0)
+ add_text_embeds = paddle.concat([negative_pooled_prompt_embeds, add_text_embeds], axis=0)
+ add_time_ids = paddle.concat([add_time_ids, add_time_ids], axis=0)
+
+ prompt_embeds = prompt_embeds._to(device)
+ add_text_embeds = add_text_embeds._to(device)
+ add_time_ids = add_time_ids._to(device).tile((batch_size * num_images_per_prompt, 1))
+
+ # 11. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 11.1 Apply denoising_end
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+                # reference-only conditioning: noise the reference latents to
+                # the current timestep before the feature-recording pass
+ noise = randn_tensor(ref_image_latents.shape, generator=generator, dtype=ref_image_latents.dtype)
+ ref_xt = self.scheduler.add_noise(
+ ref_image_latents,
+ noise,
+ t.reshape((1,)),
+ )
+ ref_xt = self.scheduler.scale_model_input(ref_xt, t)
+
+ MODE = "write"
+
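+                # feature-recording ("write") pass on the noised reference
+                # latents; the prediction itself is discarded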
+ self.unet(
+ ref_xt,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ MODE = "read"
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == paddle.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.cast(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=paddle.float16)
+ else:
+ image = latents
+ return StableDiffusionXLPipelineOutput(images=image)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)