Skip to content

Commit 8690e8b

Browse files
authored
add PAG support for SD architecture (#8725)
* add pag to sd pipelines
1 parent 7db8c3e commit 8690e8b

File tree

9 files changed

+1424
-1
lines changed

9 files changed

+1424
-1
lines changed

docs/source/en/api/pipelines/pag.md

+5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ The abstract from the paper is:
2020

2121
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
2222

23+
## StableDiffusionPAGPipeline
24+
[[autodoc]] StableDiffusionPAGPipeline
25+
- all
26+
- __call__
27+
2328
## StableDiffusionXLPAGPipeline
2429
[[autodoc]] StableDiffusionXLPAGPipeline
2530
- all

src/diffusers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@
304304
"StableDiffusionLatentUpscalePipeline",
305305
"StableDiffusionLDM3DPipeline",
306306
"StableDiffusionModelEditingPipeline",
307+
"StableDiffusionPAGPipeline",
307308
"StableDiffusionPanoramaPipeline",
308309
"StableDiffusionParadigmsPipeline",
309310
"StableDiffusionPipeline",
@@ -702,6 +703,7 @@
702703
StableDiffusionLatentUpscalePipeline,
703704
StableDiffusionLDM3DPipeline,
704705
StableDiffusionModelEditingPipeline,
706+
StableDiffusionPAGPipeline,
705707
StableDiffusionPanoramaPipeline,
706708
StableDiffusionParadigmsPipeline,
707709
StableDiffusionPipeline,

src/diffusers/pipelines/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141
)
142142
_import_structure["pag"].extend(
143143
[
144+
"StableDiffusionPAGPipeline",
144145
"StableDiffusionXLPAGPipeline",
145146
"StableDiffusionXLPAGInpaintPipeline",
146147
"StableDiffusionXLControlNetPAGPipeline",
@@ -491,6 +492,7 @@
491492
)
492493
from .musicldm import MusicLDMPipeline
493494
from .pag import (
495+
StableDiffusionPAGPipeline,
494496
StableDiffusionXLControlNetPAGPipeline,
495497
StableDiffusionXLPAGImg2ImgPipeline,
496498
StableDiffusionXLPAGInpaintPipeline,

src/diffusers/pipelines/auto_pipeline.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
4848
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
4949
from .pag import (
50+
StableDiffusionPAGPipeline,
5051
StableDiffusionXLControlNetPAGPipeline,
5152
StableDiffusionXLPAGImg2ImgPipeline,
5253
StableDiffusionXLPAGInpaintPipeline,
@@ -88,6 +89,7 @@
8889
("lcm", LatentConsistencyModelPipeline),
8990
("pixart-alpha", PixArtAlphaPipeline),
9091
("pixart-sigma", PixArtSigmaPipeline),
92+
("stable-diffusion-pag", StableDiffusionPAGPipeline),
9193
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
9294
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
9395
]

src/diffusers/pipelines/pag/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
2525
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
26+
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
2627
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
2728
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
2829
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -36,6 +37,7 @@
3637
from ...utils.dummy_torch_and_transformers_objects import *
3738
else:
3839
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
40+
from .pipeline_pag_sd import StableDiffusionPAGPipeline
3941
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
4042
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
4143
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline

0 commit comments

Comments
 (0)