Skip to content

Commit 298ab6e

Browse files
authored
Added Support of Xlabs controlnet to FluxControlNetInpaintPipeline (#9770)
* added xlabs support
1 parent 73b59f5 commit 298ab6e

File tree

1 file changed

+34
-27
lines changed

1 file changed

+34
-27
lines changed

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -932,19 +932,22 @@ def __call__(
932932
)
933933
height, width = control_image.shape[-2:]
934934

935-
# vae encode
936-
control_image = self.vae.encode(control_image).latent_dist.sample()
937-
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
938-
939-
# pack
940-
height_control_image, width_control_image = control_image.shape[2:]
941-
control_image = self._pack_latents(
942-
control_image,
943-
batch_size * num_images_per_prompt,
944-
num_channels_latents,
945-
height_control_image,
946-
width_control_image,
947-
)
935+
# xlab controlnet has a input_hint_block and instantx controlnet does not
936+
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
937+
if self.controlnet.input_hint_block is None:
938+
# vae encode
939+
control_image = self.vae.encode(control_image).latent_dist.sample()
940+
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
941+
942+
# pack
943+
height_control_image, width_control_image = control_image.shape[2:]
944+
control_image = self._pack_latents(
945+
control_image,
946+
batch_size * num_images_per_prompt,
947+
num_channels_latents,
948+
height_control_image,
949+
width_control_image,
950+
)
948951

949952
# set control mode
950953
if control_mode is not None:
@@ -954,7 +957,9 @@ def __call__(
954957
elif isinstance(self.controlnet, FluxMultiControlNetModel):
955958
control_images = []
956959

957-
for control_image_ in control_image:
960+
# xlab controlnet has a input_hint_block and instantx controlnet does not
961+
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
962+
for i, control_image_ in enumerate(control_image):
958963
control_image_ = self.prepare_image(
959964
image=control_image_,
960965
width=width,
@@ -966,19 +971,20 @@ def __call__(
966971
)
967972
height, width = control_image_.shape[-2:]
968973

969-
# vae encode
970-
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
971-
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
972-
973-
# pack
974-
height_control_image, width_control_image = control_image_.shape[2:]
975-
control_image_ = self._pack_latents(
976-
control_image_,
977-
batch_size * num_images_per_prompt,
978-
num_channels_latents,
979-
height_control_image,
980-
width_control_image,
981-
)
974+
if self.controlnet.nets[0].input_hint_block is None:
975+
# vae encode
976+
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
977+
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
978+
979+
# pack
980+
height_control_image, width_control_image = control_image_.shape[2:]
981+
control_image_ = self._pack_latents(
982+
control_image_,
983+
batch_size * num_images_per_prompt,
984+
num_channels_latents,
985+
height_control_image,
986+
width_control_image,
987+
)
982988

983989
control_images.append(control_image_)
984990

@@ -1129,6 +1135,7 @@ def __call__(
11291135
img_ids=latent_image_ids,
11301136
joint_attention_kwargs=self.joint_attention_kwargs,
11311137
return_dict=False,
1138+
controlnet_blocks_repeat=controlnet_blocks_repeat,
11321139
)[0]
11331140

11341141
# compute the previous noisy sample x_t -> x_t-1

0 commit comments

Comments
 (0)