Skip to content

Commit 193ecaa

Browse files
yiyixuxusayakpaul
authored andcommitted
fix controlnet module refactor (#9968)
* fix
1 parent 6905951 commit 193ecaa

22 files changed

+272
-58
lines changed

src/diffusers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
"ModelMixin",
108108
"MotionAdapter",
109109
"MultiAdapter",
110+
"MultiControlNetModel",
110111
"PixArtTransformer2DModel",
111112
"PriorTransformer",
112113
"SD3ControlNetModel",
@@ -592,6 +593,7 @@
592593
ModelMixin,
593594
MotionAdapter,
594595
MultiAdapter,
596+
MultiControlNetModel,
595597
PixArtTransformer2DModel,
596598
PriorTransformer,
597599
SD3ControlNetModel,

src/diffusers/models/controlnet.py

+79-6
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import Optional, Tuple, Union
15+
1416
from ..utils import deprecate
1517
from .controlnets.controlnet import ( # noqa
16-
BaseOutput,
1718
ControlNetConditioningEmbedding,
1819
ControlNetModel,
1920
ControlNetOutput,
@@ -24,19 +25,91 @@
2425
class ControlNetOutput(ControlNetOutput):
2526
def __init__(self, *args, **kwargs):
2627
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
27-
deprecate("ControlNetOutput", "0.34", deprecation_message)
28+
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
2829
super().__init__(*args, **kwargs)
2930

3031

3132
class ControlNetModel(ControlNetModel):
32-
def __init__(self, *args, **kwargs):
33+
def __init__(
34+
self,
35+
in_channels: int = 4,
36+
conditioning_channels: int = 3,
37+
flip_sin_to_cos: bool = True,
38+
freq_shift: int = 0,
39+
down_block_types: Tuple[str, ...] = (
40+
"CrossAttnDownBlock2D",
41+
"CrossAttnDownBlock2D",
42+
"CrossAttnDownBlock2D",
43+
"DownBlock2D",
44+
),
45+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
46+
only_cross_attention: Union[bool, Tuple[bool]] = False,
47+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
48+
layers_per_block: int = 2,
49+
downsample_padding: int = 1,
50+
mid_block_scale_factor: float = 1,
51+
act_fn: str = "silu",
52+
norm_num_groups: Optional[int] = 32,
53+
norm_eps: float = 1e-5,
54+
cross_attention_dim: int = 1280,
55+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
56+
encoder_hid_dim: Optional[int] = None,
57+
encoder_hid_dim_type: Optional[str] = None,
58+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
59+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
60+
use_linear_projection: bool = False,
61+
class_embed_type: Optional[str] = None,
62+
addition_embed_type: Optional[str] = None,
63+
addition_time_embed_dim: Optional[int] = None,
64+
num_class_embeds: Optional[int] = None,
65+
upcast_attention: bool = False,
66+
resnet_time_scale_shift: str = "default",
67+
projection_class_embeddings_input_dim: Optional[int] = None,
68+
controlnet_conditioning_channel_order: str = "rgb",
69+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
70+
global_pool_conditions: bool = False,
71+
addition_embed_type_num_heads: int = 64,
72+
):
3373
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
34-
deprecate("ControlNetModel", "0.34", deprecation_message)
35-
super().__init__(*args, **kwargs)
74+
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
75+
super().__init__(
76+
in_channels=in_channels,
77+
conditioning_channels=conditioning_channels,
78+
flip_sin_to_cos=flip_sin_to_cos,
79+
freq_shift=freq_shift,
80+
down_block_types=down_block_types,
81+
mid_block_type=mid_block_type,
82+
only_cross_attention=only_cross_attention,
83+
block_out_channels=block_out_channels,
84+
layers_per_block=layers_per_block,
85+
downsample_padding=downsample_padding,
86+
mid_block_scale_factor=mid_block_scale_factor,
87+
act_fn=act_fn,
88+
norm_num_groups=norm_num_groups,
89+
norm_eps=norm_eps,
90+
cross_attention_dim=cross_attention_dim,
91+
transformer_layers_per_block=transformer_layers_per_block,
92+
encoder_hid_dim=encoder_hid_dim,
93+
encoder_hid_dim_type=encoder_hid_dim_type,
94+
attention_head_dim=attention_head_dim,
95+
num_attention_heads=num_attention_heads,
96+
use_linear_projection=use_linear_projection,
97+
class_embed_type=class_embed_type,
98+
addition_embed_type=addition_embed_type,
99+
addition_time_embed_dim=addition_time_embed_dim,
100+
num_class_embeds=num_class_embeds,
101+
upcast_attention=upcast_attention,
102+
resnet_time_scale_shift=resnet_time_scale_shift,
103+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
104+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
105+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
106+
global_pool_conditions=global_pool_conditions,
107+
addition_embed_type_num_heads=addition_embed_type_num_heads,
108+
)
36109

37110

38111
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
39112
def __init__(self, *args, **kwargs):
40113
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
41-
deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message)
114+
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
42115
super().__init__(*args, **kwargs)

src/diffusers/models/controlnet_flux.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
from typing import List
17+
1618
from ..utils import deprecate, logging
1719
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
1820

@@ -23,19 +25,46 @@
2325
class FluxControlNetOutput(FluxControlNetOutput):
2426
def __init__(self, *args, **kwargs):
2527
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
26-
deprecate("FluxControlNetOutput", "0.34", deprecation_message)
28+
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
2729
super().__init__(*args, **kwargs)
2830

2931

3032
class FluxControlNetModel(FluxControlNetModel):
31-
def __init__(self, *args, **kwargs):
33+
def __init__(
34+
self,
35+
patch_size: int = 1,
36+
in_channels: int = 64,
37+
num_layers: int = 19,
38+
num_single_layers: int = 38,
39+
attention_head_dim: int = 128,
40+
num_attention_heads: int = 24,
41+
joint_attention_dim: int = 4096,
42+
pooled_projection_dim: int = 768,
43+
guidance_embeds: bool = False,
44+
axes_dims_rope: List[int] = [16, 56, 56],
45+
num_mode: int = None,
46+
conditioning_embedding_channels: int = None,
47+
):
3248
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
33-
deprecate("FluxControlNetModel", "0.34", deprecation_message)
34-
super().__init__(*args, **kwargs)
49+
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
50+
super().__init__(
51+
patch_size=patch_size,
52+
in_channels=in_channels,
53+
num_layers=num_layers,
54+
num_single_layers=num_single_layers,
55+
attention_head_dim=attention_head_dim,
56+
num_attention_heads=num_attention_heads,
57+
joint_attention_dim=joint_attention_dim,
58+
pooled_projection_dim=pooled_projection_dim,
59+
guidance_embeds=guidance_embeds,
60+
axes_dims_rope=axes_dims_rope,
61+
num_mode=num_mode,
62+
conditioning_embedding_channels=conditioning_embedding_channels,
63+
)
3564

3665

3766
class FluxMultiControlNetModel(FluxMultiControlNetModel):
3867
def __init__(self, *args, **kwargs):
3968
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
40-
deprecate("FluxMultiControlNetModel", "0.34", deprecation_message)
69+
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
4170
super().__init__(*args, **kwargs)

src/diffusers/models/controlnet_sd3.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,46 @@
2323
class SD3ControlNetOutput(SD3ControlNetOutput):
2424
def __init__(self, *args, **kwargs):
2525
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
26-
deprecate("SD3ControlNetOutput", "0.34", deprecation_message)
26+
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
2727
super().__init__(*args, **kwargs)
2828

2929

3030
class SD3ControlNetModel(SD3ControlNetModel):
31-
def __init__(self, *args, **kwargs):
31+
def __init__(
32+
self,
33+
sample_size: int = 128,
34+
patch_size: int = 2,
35+
in_channels: int = 16,
36+
num_layers: int = 18,
37+
attention_head_dim: int = 64,
38+
num_attention_heads: int = 18,
39+
joint_attention_dim: int = 4096,
40+
caption_projection_dim: int = 1152,
41+
pooled_projection_dim: int = 2048,
42+
out_channels: int = 16,
43+
pos_embed_max_size: int = 96,
44+
extra_conditioning_channels: int = 0,
45+
):
3246
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
33-
deprecate("SD3ControlNetModel", "0.34", deprecation_message)
34-
super().__init__(*args, **kwargs)
47+
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
48+
super().__init__(
49+
sample_size=sample_size,
50+
patch_size=patch_size,
51+
in_channels=in_channels,
52+
num_layers=num_layers,
53+
attention_head_dim=attention_head_dim,
54+
num_attention_heads=num_attention_heads,
55+
joint_attention_dim=joint_attention_dim,
56+
caption_projection_dim=caption_projection_dim,
57+
pooled_projection_dim=pooled_projection_dim,
58+
out_channels=out_channels,
59+
pos_embed_max_size=pos_embed_max_size,
60+
extra_conditioning_channels=extra_conditioning_channels,
61+
)
3562

3663

3764
class SD3MultiControlNetModel(SD3MultiControlNetModel):
3865
def __init__(self, *args, **kwargs):
3966
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
40-
deprecate("SD3MultiControlNetModel", "0.34", deprecation_message)
67+
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
4168
super().__init__(*args, **kwargs)

src/diffusers/models/controlnet_sparsectrl.py

+75-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
from typing import Optional, Tuple, Union
17+
1618
from ..utils import deprecate, logging
1719
from .controlnets.controlnet_sparsectrl import ( # noqa
1820
SparseControlNetConditioningEmbedding,
@@ -28,19 +30,87 @@
2830
class SparseControlNetOutput(SparseControlNetOutput):
2931
def __init__(self, *args, **kwargs):
3032
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
31-
deprecate("SparseControlNetOutput", "0.34", deprecation_message)
33+
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
3234
super().__init__(*args, **kwargs)
3335

3436

3537
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
3638
def __init__(self, *args, **kwargs):
3739
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
38-
deprecate("SparseControlNetConditioningEmbedding", "0.34", deprecation_message)
40+
deprecate(
41+
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
42+
)
3943
super().__init__(*args, **kwargs)
4044

4145

4246
class SparseControlNetModel(SparseControlNetModel):
43-
def __init__(self, *args, **kwargs):
47+
def __init__(
48+
self,
49+
in_channels: int = 4,
50+
conditioning_channels: int = 4,
51+
flip_sin_to_cos: bool = True,
52+
freq_shift: int = 0,
53+
down_block_types: Tuple[str, ...] = (
54+
"CrossAttnDownBlockMotion",
55+
"CrossAttnDownBlockMotion",
56+
"CrossAttnDownBlockMotion",
57+
"DownBlockMotion",
58+
),
59+
only_cross_attention: Union[bool, Tuple[bool]] = False,
60+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
61+
layers_per_block: int = 2,
62+
downsample_padding: int = 1,
63+
mid_block_scale_factor: float = 1,
64+
act_fn: str = "silu",
65+
norm_num_groups: Optional[int] = 32,
66+
norm_eps: float = 1e-5,
67+
cross_attention_dim: int = 768,
68+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
69+
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
70+
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
71+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
72+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
73+
use_linear_projection: bool = False,
74+
upcast_attention: bool = False,
75+
resnet_time_scale_shift: str = "default",
76+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
77+
global_pool_conditions: bool = False,
78+
controlnet_conditioning_channel_order: str = "rgb",
79+
motion_max_seq_length: int = 32,
80+
motion_num_attention_heads: int = 8,
81+
concat_conditioning_mask: bool = True,
82+
use_simplified_condition_embedding: bool = True,
83+
):
4484
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
45-
deprecate("SparseControlNetModel", "0.34", deprecation_message)
46-
super().__init__(*args, **kwargs)
85+
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
86+
super().__init__(
87+
in_channels=in_channels,
88+
conditioning_channels=conditioning_channels,
89+
flip_sin_to_cos=flip_sin_to_cos,
90+
freq_shift=freq_shift,
91+
down_block_types=down_block_types,
92+
only_cross_attention=only_cross_attention,
93+
block_out_channels=block_out_channels,
94+
layers_per_block=layers_per_block,
95+
downsample_padding=downsample_padding,
96+
mid_block_scale_factor=mid_block_scale_factor,
97+
act_fn=act_fn,
98+
norm_num_groups=norm_num_groups,
99+
norm_eps=norm_eps,
100+
cross_attention_dim=cross_attention_dim,
101+
transformer_layers_per_block=transformer_layers_per_block,
102+
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
103+
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
104+
attention_head_dim=attention_head_dim,
105+
num_attention_heads=num_attention_heads,
106+
use_linear_projection=use_linear_projection,
107+
upcast_attention=upcast_attention,
108+
resnet_time_scale_shift=resnet_time_scale_shift,
109+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
110+
global_pool_conditions=global_pool_conditions,
111+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
112+
motion_max_seq_length=motion_max_seq_length,
113+
motion_num_attention_heads=motion_num_attention_heads,
114+
concat_conditioning_mask=concat_conditioning_mask,
115+
use_simplified_condition_embedding=use_simplified_condition_embedding,
116+
)

src/diffusers/models/controlnets/controlnet_flux.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from ...loaders import PeftAdapterMixin
2323
from ...models.attention_processor import AttentionProcessor
2424
from ...models.modeling_utils import ModelMixin
25-
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
26-
from ..controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
25+
from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
26+
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
2727
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
2828
from ..modeling_outputs import Transformer2DModelOutput
2929
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -192,13 +192,13 @@ def from_transformer(
192192
num_attention_heads: int = 24,
193193
load_weights_from_transformer=True,
194194
):
195-
config = transformer.config
195+
config = dict(transformer.config)
196196
config["num_layers"] = num_layers
197197
config["num_single_layers"] = num_single_layers
198198
config["attention_head_dim"] = attention_head_dim
199199
config["num_attention_heads"] = num_attention_heads
200200

201-
controlnet = cls(**config)
201+
controlnet = cls.from_config(config)
202202

203203
if load_weights_from_transformer:
204204
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())

src/diffusers/models/controlnets/controlnet_hunyuan.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch import nn
1919

2020
from ...configuration_utils import ConfigMixin, register_to_config
21-
from ...utils import logging
21+
from ...utils import BaseOutput, logging
2222
from ..attention_processor import AttentionProcessor
2323
from ..embeddings import (
2424
HunyuanCombinedTimestepTextSizeStyleEmbedding,
@@ -27,7 +27,7 @@
2727
)
2828
from ..modeling_utils import ModelMixin
2929
from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
30-
from .controlnet import BaseOutput, Tuple, zero_module
30+
from .controlnet import Tuple, zero_module
3131

3232

3333
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

0 commit comments

Comments
 (0)