@@ -932,19 +932,22 @@ def __call__(
932
932
)
933
933
height , width = control_image .shape [- 2 :]
934
934
935
- # vae encode
936
- control_image = self .vae .encode (control_image ).latent_dist .sample ()
937
- control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
938
-
939
- # pack
940
- height_control_image , width_control_image = control_image .shape [2 :]
941
- control_image = self ._pack_latents (
942
- control_image ,
943
- batch_size * num_images_per_prompt ,
944
- num_channels_latents ,
945
- height_control_image ,
946
- width_control_image ,
947
- )
935
+ # xlab controlnet has an input_hint_block and instantx controlnet does not
936
+ controlnet_blocks_repeat = False if self .controlnet .input_hint_block is None else True
937
+ if self .controlnet .input_hint_block is None :
938
+ # vae encode
939
+ control_image = self .vae .encode (control_image ).latent_dist .sample ()
940
+ control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
941
+
942
+ # pack
943
+ height_control_image , width_control_image = control_image .shape [2 :]
944
+ control_image = self ._pack_latents (
945
+ control_image ,
946
+ batch_size * num_images_per_prompt ,
947
+ num_channels_latents ,
948
+ height_control_image ,
949
+ width_control_image ,
950
+ )
948
951
949
952
# set control mode
950
953
if control_mode is not None :
@@ -954,7 +957,9 @@ def __call__(
954
957
elif isinstance (self .controlnet , FluxMultiControlNetModel ):
955
958
control_images = []
956
959
957
- for control_image_ in control_image :
960
+ # xlab controlnet has an input_hint_block and instantx controlnet does not
961
+ controlnet_blocks_repeat = False if self .controlnet .nets [0 ].input_hint_block is None else True
962
+ for i , control_image_ in enumerate (control_image ):
958
963
control_image_ = self .prepare_image (
959
964
image = control_image_ ,
960
965
width = width ,
@@ -966,19 +971,20 @@ def __call__(
966
971
)
967
972
height , width = control_image_ .shape [- 2 :]
968
973
969
- # vae encode
970
- control_image_ = self .vae .encode (control_image_ ).latent_dist .sample ()
971
- control_image_ = (control_image_ - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
972
-
973
- # pack
974
- height_control_image , width_control_image = control_image_ .shape [2 :]
975
- control_image_ = self ._pack_latents (
976
- control_image_ ,
977
- batch_size * num_images_per_prompt ,
978
- num_channels_latents ,
979
- height_control_image ,
980
- width_control_image ,
981
- )
974
+ if self .controlnet .nets [0 ].input_hint_block is None :
975
+ # vae encode
976
+ control_image_ = self .vae .encode (control_image_ ).latent_dist .sample ()
977
+ control_image_ = (control_image_ - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
978
+
979
+ # pack
980
+ height_control_image , width_control_image = control_image_ .shape [2 :]
981
+ control_image_ = self ._pack_latents (
982
+ control_image_ ,
983
+ batch_size * num_images_per_prompt ,
984
+ num_channels_latents ,
985
+ height_control_image ,
986
+ width_control_image ,
987
+ )
982
988
983
989
control_images .append (control_image_ )
984
990
@@ -1129,6 +1135,7 @@ def __call__(
1129
1135
img_ids = latent_image_ids ,
1130
1136
joint_attention_kwargs = self .joint_attention_kwargs ,
1131
1137
return_dict = False ,
1138
+ controlnet_blocks_repeat = controlnet_blocks_repeat ,
1132
1139
)[0 ]
1133
1140
1134
1141
# compute the previous noisy sample x_t -> x_t-1
0 commit comments