
Commit b6fac9d

a-r-r-o-w authored and DN6 committed
[core] FreeNoise (#8948)
* initial work draft for freenoise; needs massive cleanup
* fix freeinit bug
* add animatediff controlnet implementation
* revert attention changes
* add freenoise
* remove old helper functions
* add decode batch size param to all pipelines
* make style
* fix copied from comments
* make fix-copies
* make style
* copy animatediff controlnet implementation from #8972
* add experimental support for num_frames not perfectly fitting context length, context stride
* make unet motion model lora work again based on #8995
* copy load video utils from #8972
* copied from AnimateDiff::prepare_latents
* address the case where last batch of frames does not match length of indices in prepare latents
* decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid
* revert sparsectrl and sdxl freenoise changes
* revert pia
* add freenoise tests
* make fix-copies
* improve docstrings
* add freenoise tests to animatediff controlnet
* update tests
* Update src/diffusers/models/unets/unet_motion_model.py
* add freenoise to animatediff pag
* address review comments
* make style
* update tests
* make fix-copies
* fix error message
* remove copied from comment
* fix imports in tests
* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent f35bdb6 commit b6fac9d

11 files changed: +911 -50 lines changed
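
For orientation, here is a minimal usage sketch of what this commit enables on AnimateDiffPipeline. The checkpoint names and the exact enable_free_noise arguments are illustrative assumptions (the method comes from the new AnimateDiffFreeNoiseMixin, but its signature is not shown in this diff); vae_batch_size is the new argument added here.

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Example checkpoints (assumed); any AnimateDiff-compatible base model should work
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# FreeNoise processes frames in overlapping temporal windows, so num_frames can
# exceed the motion module's native context length (argument names assumed)
pipe.enable_free_noise(context_length=16, context_stride=4)

output = pipe(
    prompt="a panda surfing on a wave, high quality",
    num_frames=64,
    num_inference_steps=25,
    vae_batch_size=16,  # added in this commit: number of frames decoded per VAE call
)
export_to_gif(output.frames[0], "freenoise.gif")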

src/diffusers/models/attention.py

+325 -1
Large diffs are not rendered by default.
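
The attention.py change is where most of FreeNoise lives, but it is not rendered above. As a rough, hypothetical sketch of the idea (not the actual implementation in attention.py): the temporal transformer block is run on overlapping frame windows of length context_length with stride context_stride, and the overlapping outputs are blended back together per frame.

import torch

def blend_windowed_outputs(hidden_states, temporal_block, context_length=16, context_stride=4):
    # hidden_states: (batch, num_frames, channels); temporal_block: any callable that
    # maps a (batch, window, channels) tensor to a tensor of the same shape.
    batch, num_frames, channels = hidden_states.shape
    accum = torch.zeros_like(hidden_states)
    counts = torch.zeros(num_frames, device=hidden_states.device)

    starts = list(range(0, max(num_frames - context_length, 0) + 1, context_stride))
    if starts[-1] + context_length < num_frames:
        # cover trailing frames when num_frames does not perfectly fit the stride
        starts.append(num_frames - context_length)

    for start in starts:
        end = min(start + context_length, num_frames)
        accum[:, start:end] += temporal_block(hidden_states[:, start:end])
        counts[start:end] += 1

    # plain averaging where windows overlap; the real code uses weighted blending
    return accum / counts.clamp(min=1).view(1, -1, 1)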

src/diffusers/models/unets/unet_motion_model.py

+4 -3

@@ -343,6 +343,7 @@ def custom_forward(*inputs):
 
             else:
                 hidden_states = resnet(hidden_states, temb)
+
             hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
             output_states = output_states + (hidden_states,)
@@ -536,6 +537,7 @@ def custom_forward(*inputs):
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
+
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -761,6 +763,7 @@ def custom_forward(*inputs):
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
+
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -921,9 +924,9 @@ def custom_forward(*inputs):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet), hidden_states, temb
                 )
-
             else:
                 hidden_states = resnet(hidden_states, temb)
+
             hidden_states = motion_module(hidden_states, num_frames=num_frames)
 
             if self.upsamplers is not None:
@@ -1923,7 +1926,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
     def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
         """
         Sets the attention processor to use [feed forward
@@ -1953,7 +1955,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
         for module in self.children():
             fn_recursive_feed_forward(module, chunk_size, dim)
 
-    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
     def disable_forward_chunking(self) -> None:
         def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
             if hasattr(module, "set_chunk_feed_forward"):

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

+27 -11

@@ -42,6 +42,7 @@
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for text-to-video generation.
@@ -394,15 +396,20 @@ def prepare_ip_adapter_image_embeds(
 
         return ip_adapter_image_embeds
 
-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def decode_latents(self, latents, vae_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], vae_batch_size):
+            batch_latents = latents[i : i + vae_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
@@ -495,22 +502,28 @@ def check_inputs(
                         f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                     )
 
-    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
     ):
+        # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+        if self.free_noise_enabled:
+            latents = self._prepare_latents_free_noise(
+                batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
+            )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             num_channels_latents,
             num_frames,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
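
For intuition about the _prepare_latents_free_noise call referenced in this hunk (Equation (7) of the FreeNoise paper): noise is sampled once for the first context window, and later frames reuse shuffled copies of that window, keeping distant frames correlated. A simplified sketch under assumed argument names, not the pipeline's actual implementation:

import torch

def free_noise_style_latents(shape, context_length=16, context_stride=4, generator=None):
    # shape: (batch, channels, num_frames, height, width)
    latents = torch.randn(shape, generator=generator)
    num_frames = shape[2]

    for start in range(context_length, num_frames, context_stride):
        # indices of the previous window, shuffled so the reused noise is locally decorrelated
        window = torch.arange(start - context_length, start)
        shuffled = window[torch.randperm(context_length, generator=generator)]
        end = min(start + context_stride, num_frames)
        latents[:, :, start:end] = latents[:, :, shuffled[: end - start]]
    return latents

With context_length=16 and context_stride=4, a 64-frame video keeps reusing the same 16 noise frames in shuffled order every 4 frames, which is what lets the windowed motion module see consistent noise across windows.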
@@ -569,6 +582,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        vae_batch_size: int = 16,
         **kwargs,
     ):
         r"""
@@ -637,6 +651,8 @@
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            vae_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling `decode_latents` method.
 
         Examples:
 
@@ -808,7 +824,7 @@
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, vae_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models

src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py

+21 -11

@@ -30,6 +30,7 @@
 from ...video_processor import VideoProcessor
 from ..controlnet.multicontrolnet import MultiControlNetModel
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
@@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline(
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for text-to-video generation with ControlNet guidance.
@@ -432,15 +434,16 @@ def prepare_ip_adapter_image_embeds(
 
         return ip_adapter_image_embeds
 
-    def decode_latents(self, latents, decode_batch_size: int = 16):
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, vae_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
         video = []
-        for i in range(0, latents.shape[0], decode_batch_size):
-            batch_latents = latents[i : i + decode_batch_size]
+        for i in range(0, latents.shape[0], vae_batch_size):
+            batch_latents = latents[i : i + vae_batch_size]
             batch_latents = self.vae.decode(batch_latents).sample
             video.append(batch_latents)
 
@@ -608,22 +611,29 @@ def check_inputs(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
 
-    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
    ):
+        # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+        if self.free_noise_enabled:
+            latents = self._prepare_latents_free_noise(
+                batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
+            )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         shape = (
             batch_size,
             num_channels_latents,
             num_frames,
             height // self.vae_scale_factor,
             width // self.vae_scale_factor,
         )
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
 
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -718,7 +728,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        decode_batch_size: int = 16,
+        vae_batch_size: int = 16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -1054,7 +1064,7 @@
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents, decode_batch_size)
+            video_tensor = self.decode_latents(latents, vae_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

+28 -10

@@ -35,6 +35,7 @@
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
+from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from .pipeline_output import AnimateDiffPipelineOutput
 
@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,
     FreeInitMixin,
+    AnimateDiffFreeNoiseMixin,
 ):
     r"""
     Pipeline for video-to-video generation.
@@ -498,15 +500,29 @@ def prepare_ip_adapter_image_embeds(
 
         return ip_adapter_image_embeds
 
-    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
-    def decode_latents(self, latents):
+    def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor:
+        latents = []
+        for i in range(0, len(video), vae_batch_size):
+            batch_video = video[i : i + vae_batch_size]
+            batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
+            latents.append(batch_video)
+        return torch.cat(latents)
+
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
+    def decode_latents(self, latents, vae_batch_size: int = 16):
         latents = 1 / self.vae.config.scaling_factor * latents
 
         batch_size, channels, num_frames, height, width = latents.shape
         latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
 
-        image = self.vae.decode(latents).sample
-        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
+        video = []
+        for i in range(0, latents.shape[0], vae_batch_size):
+            batch_latents = latents[i : i + vae_batch_size]
+            batch_latents = self.vae.decode(batch_latents).sample
+            video.append(batch_latents)
+
+        video = torch.cat(video)
+        video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         video = video.float()
         return video
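
For context on the retrieve_latents helper used by the new encode_video method above: in the diffusers pipelines it roughly does the following (paraphrased for reference, not part of this diff), returning latents from an AutoencoderKL encode output either by sampling the posterior, taking its mode, or reading precomputed latents.

def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")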
@@ -622,6 +638,7 @@ def prepare_latents(
         device,
         generator,
         latents=None,
+        vae_batch_size: int = 16,
     ):
         if latents is None:
             num_frames = video.shape[1]
@@ -656,13 +673,10 @@
                     )
 
             init_latents = [
-                retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
-                for i in range(batch_size)
+                self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size)
             ]
         else:
-            init_latents = [
-                retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
-            ]
+            init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video]
 
         init_latents = torch.cat(init_latents, dim=0)
 
@@ -747,6 +761,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        vae_batch_size: int = 16,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -822,6 +837,8 @@
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            vae_batch_size (`int`, defaults to `16`):
+                The number of frames to decode at a time when calling `decode_latents` method.
 
         Examples:
 
@@ -923,6 +940,7 @@
             device=device,
             generator=generator,
             latents=latents,
+            vae_batch_size=vae_batch_size,
         )
 
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -990,7 +1008,7 @@
         if output_type == "latent":
             video = latents
         else:
-            video_tensor = self.decode_latents(latents)
+            video_tensor = self.decode_latents(latents, vae_batch_size)
             video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
 
         # 10. Offload all models
