[core] Move community AnimateDiff ControlNet to core #8972

Merged
merged 21 commits into from
Jul 30, 2024

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Jul 25, 2024

What does this PR do?

Moves the community implementation of AnimateDiff ControlNet to core. Part of supporting long vid2vid generations here.

Code
import torch
from diffusers import AnimateDiffControlNetPipeline, AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
from diffusers.utils import export_to_gif, load_video

# Additionally, you will need to preprocess videos before they can be used with the ControlNet
# HF maintains just the right package for it: `pip install controlnet_aux`
from controlnet_aux.processor import ZoeDetector

# Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file
# Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()
controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16)

# We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",
    motion_adapter=motion_adapter,
    controlnet=controlnet,
    vae=vae,
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
pipe.set_adapters(["lcm-lora"], [0.8])

depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
conditioning_frames = []

with pipe.progress_bar(total=len(video)) as progress_bar:
    for frame in video:
        conditioning_frames.append(depth_detector(frame))
        progress_bar.update()

prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality"
negative_prompt = "bad quality, worst quality"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_frames=len(video),
    num_inference_steps=10,
    guidance_scale=2.0,
    conditioning_frames=conditioning_frames,
    generator=torch.Generator().manual_seed(42),
).frames[0]

export_to_gif(video, "animatediff_controlnet.gif", fps=8)
Source video: raccoon playing a guitar
Output video: a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality

Documentation PR requires merging: https://huggingface.co/datasets/huggingface/documentation-images/discussions/351

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DN6 @yiyixuxu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu yiyixuxu requested a review from DN6 July 25, 2024 17:46
@@ -47,3 +50,77 @@ def load_image(
        image = image.convert("RGB")

    return image


def load_video(
Member Author

Added this here to make it a little easier to deal with videos instead of using imageio (which could maybe be a separate backend here). Let me know if this is out of scope and would need a separate PR for a cleaner commit history.

if was_tempfile_created:
    os.remove(video_path)

elif isinstance(video, list) and all(isinstance(frame, PIL.Image.Image) for frame in video):
Collaborator

If passing a list of PIL images, we would just return the same list back? Why support passing in a list of images then?

Member Author

I followed the implementation of load_image. The same list might not be returned, since the code that follows either calls convert_method (if one is provided) or converts the images to RGB (possibly from RGBA/HSV/etc.).

Collaborator

Hmm seems like there's been a bit of back and forth about it
#6479
#6904

IMO it doesn't make much sense to pass a list of already loaded images into load_video just to run preprocessing. If the user is at a point where they already have this list of images, it should be up to them to preprocess on their own.

I think the conversion to RGB by default should also be removed. A loading function should just return the objects. If additional processing has to be done, it can be done via the convert_method
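
The design suggested in this comment can be sketched as follows. This is only an illustration of the idea, not the `load_video` implementation from this PR: `load_frames`, `to_rgb`, and the `(mode, payload)` frame tuples are hypothetical stand-ins, and decoding is replaced by copying an in-memory sequence. The loader returns frames exactly as decoded, and any conversion happens only when the caller passes a `convert_method`.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical frame representation: (color mode, raw payload).
Frame = Tuple[str, bytes]


def load_frames(
    source: Sequence[Frame],
    convert_method: Optional[Callable[[List[Frame]], List[Frame]]] = None,
) -> List[Frame]:
    """Hypothetical loader: return frames as-is, with opt-in post-processing."""
    # Stand-in for decoding a file or URL into a list of frames.
    frames = list(source)
    # No default conversion: processing runs only if the caller asks for it.
    if convert_method is not None:
        frames = convert_method(frames)
    return frames


def to_rgb(frames: List[Frame]) -> List[Frame]:
    """Hypothetical converter: relabel every frame as RGB."""
    return [("RGB", payload) for _mode, payload in frames]


raw = [("RGBA", b"\x00"), ("RGB", b"\x01")]
untouched = load_frames(raw)                         # modes preserved
converted = load_frames(raw, convert_method=to_rgb)  # explicit opt-in
```

With this shape, a caller who already holds a list of frames has no reason to route it through the loader; they can call their converter directly.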

Member Author

I see, okay yeah that makes sense. Will remove this change

a-r-r-o-w added a commit that referenced this pull request Jul 28, 2024
@a-r-r-o-w a-r-r-o-w mentioned this pull request Jul 28, 2024
Collaborator

@DN6 DN6 left a comment

Left a comment on the load video function, but the rest looks good to me 👍🏽

@a-r-r-o-w
Member Author

Left a comment on the load video function, but the rest looks good to me 👍🏽

thanks, will merge once CI is green!

@a-r-r-o-w a-r-r-o-w merged commit e5b94b4 into main Jul 30, 2024
18 checks passed
@a-r-r-o-w a-r-r-o-w deleted the animatediff/controlnet branch July 30, 2024 11:40
DN6 added a commit that referenced this pull request Aug 7, 2024
* initial work draft for freenoise; needs massive cleanup

* fix freeinit bug

* add animatediff controlnet implementation

* revert attention changes

* add freenoise

* remove old helper functions

* add decode batch size param to all pipelines

* make style

* fix copied from comments

* make fix-copies

* make style

* copy animatediff controlnet implementation from #8972

* add experimental support for num_frames not perfectly fitting context length, ocntext stride

* make unet motion model lora work again based on #8995

* copy load video utils from #8972

* copied from AnimateDiff::prepare_latents

* address the case where last batch of frames does not match length of indices in prepare latents

* decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid

* revert sparsectrl and sdxl freenoise changes

* revert pia

* add freenoise tests

* make fix-copies

* improve docstrings

* add freenoise tests to animatediff controlnet

* update tests

* Update src/diffusers/models/unets/unet_motion_model.py

* add freenoise to animatediff pag

* address review comments

* make style

* update tests

* make fix-copies

* fix error message

* remove copied from comment

* fix imports in tests

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
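
The commit message above mentions handling `num_frames` that does not fit the context length and context stride exactly. One way such windowing could work is sketched below. This is an assumption for illustration only, not the FreeNoise code from the referenced commit: the function name `context_windows` and the choice to shift the final window back so that it ends on the last frame are both hypothetical.

```python
from typing import List, Tuple


def context_windows(
    num_frames: int, context_length: int, context_stride: int
) -> List[Tuple[int, int]]:
    """Hypothetical helper: half-open (start, end) frame windows covering all frames."""
    if num_frames <= 0 or context_length <= 0 or context_stride <= 0:
        raise ValueError("all arguments must be positive")
    if context_stride > context_length:
        # A stride larger than the window would leave frames uncovered.
        raise ValueError("context_stride must not exceed context_length")
    if num_frames <= context_length:
        # Everything fits in a single (possibly shorter) window.
        return [(0, num_frames)]

    windows = []
    start = 0
    # Emit full-length windows that end strictly before the last frame.
    while start + context_length < num_frames:
        windows.append((start, start + context_length))
        start += context_stride
    # Assumed policy: shift the final window back so it ends at num_frames.
    windows.append((num_frames - context_length, num_frames))
    return windows


even_fit = context_windows(20, 16, 4)
uneven_fit = context_windows(18, 16, 4)
```

Shifting the last window keeps every window the same length at the cost of extra overlap; truncating it instead would be another reasonable policy.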
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* add animatediff controlnet to core

* make style; remove unused method

* fix copied from comment

* add tests

* changes to make tests work

* add utility function to load videos

* update docs

* update pipeline example

* make style

* update docs with example

* address review comments

* add latest freeinit test from #8969

* LoraLoaderMixin -> StableDiffusionLoraLoaderMixin

* fix docs

* Update src/diffusers/utils/loading_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fix: variable out of scope

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>