Add PAG support to StableDiffusionControlNetPAGInpaintPipeline #8875

Merged
merged 21 commits into from
Oct 1, 2024

Conversation

juancopi81
Contributor

What does this PR do?

Adds PAG (Perturbed-Attention Guidance) support for SD 1.5 with ControlNet and inpainting models (StableDiffusionControlNetPAGInpaintPipeline).
Continuation of #8710. This pipeline was not mentioned in #8710, but I think it is pretty cool too 😄

Before submitting

Who can review?

@a-r-r-o-w @yiyixuxu
Anyone in the community is free to review the PR once the tests have passed.

Code example

import cv2
from diffusers import AutoPipelineForInpainting, ControlNetModel, DDIMScheduler
from diffusers.utils import load_image
import numpy as np
from PIL import Image
import torch

init_image = load_image("https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png")
init_image = init_image.resize((512, 512))

mask_image = load_image("https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png")
mask_image = mask_image.resize((512, 512))

def make_canny_condition(image):
    image = np.array(image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    image = Image.fromarray(image)
    return image

control_image = make_canny_condition(init_image)

controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)

pipe = AutoPipelineForInpainting.from_pretrained(
    "jayparmr/icbinp_v8_inpaint_v2", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True, requires_safety_checker=False, safety_checker=None,
)

pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

pag_scales = [0.0, 3.0]
guidance_scales = [0.0, 2.0]

grid = []
for pag_scale in pag_scales:
    for guidance_scale in guidance_scales:
        generator = torch.Generator(device="cpu").manual_seed(0)
        # generate image
        image = pipe(
            "a man with ray-ban sunglasses",
            negative_prompt="low quality, bad quality",
            num_inference_steps=20,
            generator=generator,
            eta=1.0,
            image=init_image,
            mask_image=mask_image,
            control_image=control_image,
            pag_scale=pag_scale,
            strength=0.99,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=0.2
        ).images
        grid.append(image[0])

from diffusers.utils import make_image_grid
make_image_grid(grid, rows=len(pag_scales), cols=len(guidance_scales))

[Image: pag_controlnet_inpaint]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@a-r-r-o-w a-r-r-o-w left a comment

Thanks for working on this! The PR looks absolutely perfect with all required PAG changes. I think you need to run make style for the CI to go green. @yiyixuxu WDYT?

control_image = control_images if isinstance(control_image, list) else control_images[0]
controlnet_prompt_embeds = prompt_embeds

# 7.2 Create tensor stating which controlnets to keep
Member

The numbering from here onwards seems incorrect

Contributor Author

Oh you're right! Changing it...

mask_image,
height,
width,
callback_steps,
Member

I think you can remove this parameter completely instead of passing None

Contributor Author

Oh sure! I just removed it

@juancopi81
Contributor Author

Hi @a-r-r-o-w ,

Thank you very much for your feedback 🚀 ! I just updated my PR trying to address all the issues.

Thanks again.

@juancopi81
Contributor Author

Sorry, I made a small mistake that caused some tests to fail. It should be fixed now!

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks! super cool PR! I left one comment.
also, can we make sure it works with from_pipe()?

@@ -944,7 +946,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
        if "enable_pag" in kwargs:
            enable_pag = kwargs.pop("enable_pag")
            if enable_pag:
                orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline")
                if "controlnet" in kwargs:
Collaborator

can you provide an example where this code addition is needed? I think it should already be addressed in line 944, no?

Contributor Author

Hi @yiyixuxu,

Thank you very much for your comments. Sorry if I am missing something here, but my idea is as follows:
There are 4 cases that need to be handled:

  • No ControlNet and no PAG -> Keep Pipeline

  • With ControlNet and no PAG -> Replace Pipeline with ControlNetPipeline

  • No ControlNet and with PAG -> Replace Pipeline with PAGPipeline

  • With ControlNet and with PAG -> Replace Pipeline with ControlNetPAGPipeline, which will override the first case.

That's why I added the nested conditions, but maybe I am missing something 🤔
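For illustration, the four cases above can be sketched with plain string replacement. This is a minimal sketch under stated assumptions, not the actual from_pretrained code: resolve_class_name and its boolean arguments are hypothetical names, and "StableDiffusionPipeline" is used only as an example starting class name.

```python
def resolve_class_name(base_class_name: str, has_controlnet: bool, enable_pag: bool) -> str:
    """Hypothetical helper: derive a target class name for the four cases."""
    name = base_class_name
    if has_controlnet:
        # "...Pipeline" -> "...ControlNetPipeline"
        name = name.replace("Pipeline", "ControlNetPipeline")
    if enable_pag:
        # Applied after the ControlNet step, so both flags give "...ControlNetPAGPipeline"
        name = name.replace("Pipeline", "PAGPipeline")
    return name


# One entry per (has_controlnet, enable_pag) combination.
cases = {
    (controlnet, pag): resolve_class_name("StableDiffusionPipeline", controlnet, pag)
    for controlnet in (False, True)
    for pag in (False, True)
}
```

Applying the two replacements one after the other handles the combined case without a nested condition.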

Collaborator

I see! thanks for explaining it, you were absolutely right!

can we do this instead (to be consistent with AutoPipelineForText2Image and AutoPipelineForImage2Image)?

        if "controlnet" in kwargs:
            orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
        if "enable_pag" in kwargs:
            enable_pag = kwargs.pop("enable_pag")
            if enable_pag:
                orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")

orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")

Contributor Author

Oh sure, done! This way is better!

@juancopi81
Contributor Author

thanks! super cool PR! I left one comment. also, can we make sure it works with from_pipe()?

Hi @yiyixuxu,

Thanks for your support and feedback. I updated my PR with a new test. I ran it with:

tests/pipelines/test_pipelines_auto.py::AutoPipelineFastTest::test_from_pipe_pag_controlnet_inpaint

and it seems to be working fine. What do you think? Would this be enough to test that it works with from_pipe()?

Please let me know 😃

Collaborator

@yiyixuxu yiyixuxu left a comment

very nice! I left some comments for the test. Since all the from_pipe tests are based on SDXL, it is ok too if you want to skip it! Just let us know!

Let's merge this soon

@@ -283,6 +283,26 @@ def test_from_pipe_pag_inpaint(self):
        pipe = AutoPipelineForInpainting.from_pipe(pipe_pag, enable_pag=False)
        assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline"

    def test_from_pipe_pag_controlnet_inpaint(self):
Collaborator

nice! we can combine this test with test_from_pipe_pag_inpaint and make sure it covers all the scenarios, similar to https://github.com/huggingface/diffusers/blob/main/tests/pipelines/test_pipelines_auto.py#L141

  1. test_from_pipe_pag_inpaint tests the inpainting XL pipeline; we can update it to test the SD pipeline instead, because we now have ControlNet + inpaint + PAG for SD but not for SDXL
  2. make sure:
    • we test from_pipe on the inpainting pipeline, the inpainting ControlNet pipeline, and the PAG inpainting ControlNet pipeline
    • for each, we test the combinations of enable_pag and controlnet

but if you want to skip the test, it's ok too! we can update the test when we have the SDXL PAG inpainting pipeline as well
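The scenarios requested in this comment can be enumerated up front. The following is a minimal sketch only: the labels in SOURCE_PIPELINES and the function from_pipe_scenarios are hypothetical, and the code lists the combinations without calling diffusers.

```python
from itertools import product

# Hypothetical labels for the three source pipelines named in the review comment.
SOURCE_PIPELINES = ("inpaint", "controlnet_inpaint", "pag_controlnet_inpaint")


def from_pipe_scenarios():
    """Return every (source, enable_pag, pass_controlnet) combination to cover."""
    return [
        {"source": source, "enable_pag": enable_pag, "pass_controlnet": pass_controlnet}
        for source, enable_pag, pass_controlnet in product(
            SOURCE_PIPELINES, (False, True), (False, True)
        )
    ]


scenarios = from_pipe_scenarios()
```

Each entry could then drive one from_pipe call in a parametrized test, with the expected class name checked per scenario.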

Contributor Author

Thanks @yiyixuxu! I'm short on time right now, so I removed it. Would this work for now? I can come back later to finish it properly, but I wanted to get the PR merged sooner. Please let me know what you think! 😄

@juancopi81
Contributor Author

Thanks a lot @yiyixuxu for your feedback. I addressed your comment in src/diffusers/pipelines/auto_pipeline.py and removed, for now, the test I added so the PR can be merged. I can come back later to finish it properly, or if someone wants to take it, please let me know!

@yiyixuxu yiyixuxu added the PAG label Sep 4, 2024
@yiyixuxu yiyixuxu mentioned this pull request Sep 7, 2024
6 tasks
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 29, 2024
@yiyixuxu yiyixuxu removed the stale Issues that haven't received updates label Sep 30, 2024
Collaborator

@yiyixuxu yiyixuxu left a comment

thanks so much for the PR
and very sorry I forgot about it and am only merging it now

@yiyixuxu yiyixuxu merged commit 33fafe3 into huggingface:main Oct 1, 2024
15 checks passed
@juancopi81 juancopi81 deleted the pag_controlnet_inpaint_sd15 branch October 1, 2024 13:42
@juancopi81
Contributor Author

No worries @yiyixuxu! Thank you very much for the support. I am very happy that it is now merged 🚀

leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request Oct 11, 2024
…ngface#8875)

* Add pag to controlnet inpainting pipeline


---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Add pag to controlnet inpainting pipeline


---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>

4 participants