diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 1e9e28471d89..604394f08408 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _keep_in_fp32_modules = None
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
+    _repeated_blocks = []
 
     def __init__(self):
         super().__init__()
@@ -1402,6 +1403,64 @@ def float(self, *args):
         else:
             return super().float(*args)
 
+    def compile_repeated_blocks(self, *args, **kwargs):
+        """
+        Compiles *only* the frequently repeated sub-modules of a model (e.g. the
+        Transformer layers) instead of compiling the entire model. This
+        technique—often called **regional compilation** (see the PyTorch recipe
+        https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)—
+        can reduce end-to-end compile time substantially, while preserving the
+        runtime speed-ups you would expect from a full `torch.compile`.
+
+        The set of sub-modules to compile is discovered in one of two ways:
+
+        1. **`_repeated_blocks`** – Preferred. Define this attribute on your
+           subclass as a list/tuple of class names (strings). Every module whose
+           class name matches will be compiled.
+
+        2. **`_no_split_modules`** – Fallback. If the preferred attribute is
+           missing or empty, we fall back to the legacy Diffusers attribute
+           `_no_split_modules`.
+
+        Once discovered, each matching sub-module is compiled by calling
+        ``submodule.compile(*args, **kwargs)``. Any positional or keyword
+        arguments you supply to :meth:`compile_repeated_blocks` are forwarded
+        verbatim to `torch.compile`.
+
+        Raises:
+            ValueError: If both `_repeated_blocks` and `_no_split_modules` are
+                empty, or if no sub-module of the model has a class name
+                listed in the selected attribute.
+        """
+        repeated_blocks = getattr(self, "_repeated_blocks", None)
+
+        if not repeated_blocks:
+            logger.warning("_repeated_blocks attribute is empty. Using _no_split_modules to find compile regions.")
+
+            repeated_blocks = getattr(self, "_no_split_modules", None)
+
+        if not repeated_blocks:
+            raise ValueError(
+                "Both _repeated_blocks and _no_split_modules attribute are empty. "
+                "Set _repeated_blocks for the model to benefit from faster compilation. "
+            )
+
+        # A bare string would turn the `in` test below into a substring match
+        # (e.g. "Block" in "FluxTransformerBlock"); treat it as a single name.
+        if isinstance(repeated_blocks, str):
+            repeated_blocks = [repeated_blocks]
+
+        has_compiled_region = False
+        for submod in self.modules():
+            if submod.__class__.__name__ in repeated_blocks:
+                has_compiled_region = True
+                submod.compile(*args, **kwargs)
+
+        if not has_compiled_region:
+            raise ValueError(
+                f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
+            )
+
     @classmethod
     def _load_pretrained_model(
         cls,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 541576b13b78..1b6c582d5d06 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index baa0ede4184e..e89860de1627 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _repeated_blocks = ["WanTransformerBlock"]
 
     @register_to_config
     def __init__(