From 27df73ea4a06a6c45313d4915a6867266005ce1f Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Thu, 31 Oct 2024 14:45:48 +0530
Subject: [PATCH] reduce explicit device transfers and typecasting in flux.

---
 src/diffusers/pipelines/flux/pipeline_flux.py | 6 +++---
 src/diffusers/pipelines/flux/pipeline_flux_controlnet.py | 4 ++--
 .../flux/pipeline_flux_controlnet_image_to_image.py | 6 +++---
 .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++---
 src/diffusers/pipelines/flux/pipeline_flux_img2img.py | 6 +++---
 src/diffusers/pipelines/flux/pipeline_flux_inpaint.py | 6 +++---
 6 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 040d935f1b88..ab4e0fc4d255 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -371,7 +371,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -427,7 +427,7 @@ def check_inputs(
 
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
@@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 9f33e26013d5..ceb095da0501 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -452,7 +452,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
@@ -462,7 +462,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 810c970ab715..be0d8f85cd85 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -407,7 +407,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -495,7 +495,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
@@ -505,7 +505,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 1f5f83561f1c..2a1fbce5e81e 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -417,7 +417,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -522,7 +522,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
@@ -532,7 +532,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index 47f9f268ee9d..aa1a3e7fc3a4 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -391,7 +391,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -479,7 +479,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
@@ -489,7 +489,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index 766f9864839e..97824258b28f 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -395,7 +395,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -500,7 +500,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
@@ -510,7 +510,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
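The pattern applied throughout this patch, shown as a minimal standalone sketch (the helper name make_latent_ids and the example shapes are illustrative, not part of the diffusers API): pass device= and dtype= to the tensor factory so the tensor is created where it will be used, rather than allocating with the CPU/float32 defaults and paying for an extra allocation, cast, and host-to-device copy via .to(). One consequence of creating the tensor directly on an accelerator is that anything combined with it, such as the torch.arange ranges below, must be created on the same device as well.

import torch

def make_latent_ids(height: int, width: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # Before: ids = torch.zeros(height, width, 3); ids = ids.to(device=device, dtype=dtype)
    # After: allocate directly on the target device with the target dtype (no extra copy/cast).
    ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
    # The row/column ranges must share the tensor's device and dtype to be broadcast-added.
    ids[..., 1] += torch.arange(height, device=device, dtype=dtype)[:, None]
    ids[..., 2] += torch.arange(width, device=device, dtype=dtype)[None, :]
    # Flatten the spatial grid into (height * width, 3) id rows.
    return ids.reshape(height * width, 3)

# Example usage (assumes a CUDA device is available):
# ids = make_latent_ids(64, 64, device=torch.device("cuda"), dtype=torch.bfloat16)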