
[rfc][compile] compile method for DiffusionPipeline #11705


Open · wants to merge 1 commit into base: main
49 changes: 49 additions & 0 deletions src/diffusers/models/modeling_utils.py
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_repeated_blocks = []

def __init__(self):
super().__init__()
@@ -1402,6 +1403,54 @@ def float(self, *args):
else:
return super().float(*args)

def compile_repeated_blocks(self, *args, **kwargs):
"""
Compiles *only* the frequently repeated sub-modules of a model (e.g. the
Transformer layers) instead of compiling the entire model. This
technique, often called **regional compilation** (see the PyTorch recipe at
https://docs.pytorch.org/tutorials/recipes/regional_compilation.html),
can reduce end-to-end compile time substantially while preserving the
runtime speed-ups you would expect from a full `torch.compile`.

The set of sub-modules to compile is discovered in one of two ways:

1. **`_repeated_blocks`** – Preferred. Define this attribute on your
subclass as a list/tuple of class names (strings). Every module whose
class name matches will be compiled.

2. **`_no_split_modules`** – Fallback. If the preferred attribute is
missing or empty, we fall back to the legacy Diffusers attribute
`_no_split_modules`.

Once discovered, each matching sub-module is compiled by calling
``submodule.compile(*args, **kwargs)``. Any positional or keyword
arguments you supply to :meth:`compile_repeated_blocks` are forwarded
verbatim to `torch.compile`.
"""
repeated_blocks = getattr(self, "_repeated_blocks", None)

if not repeated_blocks:
logger.warning("_repeated_blocks attribute is empty. Using _no_split_modules to find compile regions.")

repeated_blocks = getattr(self, "_no_split_modules", None)

if not repeated_blocks:
raise ValueError(
"Both the `_repeated_blocks` and `_no_split_modules` attributes are empty. "
"Set `_repeated_blocks` on the model class to benefit from faster compilation."
)

has_compiled_region = False
for submod in self.modules():
if submod.__class__.__name__ in repeated_blocks:
has_compiled_region = True
submod.compile(*args, **kwargs)

if not has_compiled_region:
raise ValueError(
f"Regional compilation failed because none of the {repeated_blocks} classes were found in the model."
)

@classmethod
def _load_pretrained_model(
cls,
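For reference, here is a minimal, self-contained sketch of the discovery-and-compile logic introduced above, using only `torch`. The `SmallBlock` and `ToyModel` classes are hypothetical and exist purely for illustration; the matching rule mirrors what `compile_repeated_blocks` does.

```python
# A minimal sketch of the regional-compilation lookup added in this diff.
# SmallBlock and ToyModel are hypothetical classes used only for illustration.
import torch
from torch import nn


class SmallBlock(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))


class ToyModel(nn.Module):
    # Mirrors the `_repeated_blocks` attribute introduced above.
    _repeated_blocks = ["SmallBlock"]

    def __init__(self, dim: int = 32, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList([SmallBlock(dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyModel()

# Compile every sub-module whose class name appears in `_repeated_blocks`,
# instead of compiling the whole model (same matching rule as the method above).
for submod in model.modules():
    if submod.__class__.__name__ in model._repeated_blocks:
        submod.compile()

out = model(torch.randn(2, 32))
```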
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = _no_split_modules

@register_to_config
def __init__(
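With `_repeated_blocks` defined on `FluxTransformer2DModel`, usage would look roughly like the sketch below once this RFC lands. The checkpoint id, `subfolder`, and `torch_dtype` are assumptions made purely for illustration.

```python
# Hypothetical usage after this change; the checkpoint id and dtype are
# illustrative assumptions, not part of this PR.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Regional compilation: only the repeated FluxTransformerBlock and
# FluxSingleTransformerBlock instances are compiled; extra arguments are
# forwarded verbatim to torch.compile.
transformer.compile_repeated_blocks(fullgraph=True)
```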
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = _no_split_modules

@register_to_config
def __init__(