Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[refactor] enhance readability of flux related pipelines #9711

Merged
merged 4 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2198,8 +2198,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2],
model_input.shape[3],
model_input.shape[2] // 2,
model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
Expand Down Expand Up @@ -2253,8 +2253,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
height=model_input.shape[2] * vae_scale_factor,
width=model_input.shape[3] * vae_scale_factor,
vae_scale_factor=vae_scale_factor,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/controlnet/train_controlnet_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,8 +1256,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
batch_size=pixel_latents_tmp.shape[0],
height=pixel_latents_tmp.shape[2],
width=pixel_latents_tmp.shape[3],
height=pixel_latents_tmp.shape[2] // 2,
width=pixel_latents_tmp.shape[3] // 2,
device=pixel_values.device,
dtype=pixel_values.dtype,
)
Expand Down
10 changes: 5 additions & 5 deletions examples/dreambooth/train_dreambooth_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,12 +1540,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
model_input = model_input.to(dtype=weight_dtype)

vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2],
model_input.shape[3],
model_input.shape[2] // 2,
model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
Expand Down Expand Up @@ -1601,8 +1601,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
height=model_input.shape[2] * vae_scale_factor,
width=model_input.shape[3] * vae_scale_factor,
vae_scale_factor=vae_scale_factor,
)

Expand Down
10 changes: 5 additions & 5 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,12 +1645,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)

vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2],
model_input.shape[3],
model_input.shape[2] // 2,
model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
Expand Down Expand Up @@ -1704,8 +1704,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
height=model_input.shape[2] * vae_scale_factor,
width=model_input.shape[3] * vae_scale_factor,
vae_scale_factor=vae_scale_factor,
)

Expand Down
26 changes: 14 additions & 12 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
self.default_sample_size = 128

def _get_t5_prompt_embeds(
self,
Expand Down Expand Up @@ -386,8 +386,10 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
Expand Down Expand Up @@ -425,9 +427,9 @@ def check_inputs(

@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -452,10 +454,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand Down Expand Up @@ -499,8 +501,8 @@ def prepare_latents(
generator,
latents=None,
):
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)

Expand All @@ -517,7 +519,7 @@ def prepare_latents(
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

return latents, latent_image_ids

Expand Down
26 changes: 14 additions & 12 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ def __init__(
controlnet=controlnet,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
self.default_sample_size = 128

def _get_t5_prompt_embeds(
self,
Expand Down Expand Up @@ -410,8 +410,10 @@ def check_inputs(
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
Expand Down Expand Up @@ -450,9 +452,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -479,10 +481,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand All @@ -498,8 +500,8 @@ def prepare_latents(
generator,
latents=None,
):
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)

Expand All @@ -516,7 +518,7 @@ def prepare_latents(
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

return latents, latent_image_ids

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,13 @@ def __init__(
controlnet=controlnet,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
self.default_sample_size = 128

# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
Expand Down Expand Up @@ -453,8 +453,10 @@ def check_inputs(
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
)

if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
Expand Down Expand Up @@ -493,9 +495,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -522,10 +524,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand All @@ -549,11 +551,11 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
height = int(height) // self.vae_scale_factor
width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
Expand Down Expand Up @@ -852,7 +854,7 @@ def __call__(
control_mode = control_mode.reshape([-1, 1])

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
Expand Down
Loading
Loading