From 27df73ea4a06a6c45313d4915a6867266005ce1f Mon Sep 17 00:00:00 2001
From: sayakpaul <spsayakpaul@gmail.com>
Date: Thu, 31 Oct 2024 14:45:48 +0530
Subject: [PATCH] Reduce explicit device transfers and typecasting in Flux pipelines

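Create `text_ids` and `latent_image_ids` directly on the target device and
in the target dtype instead of materializing them on the CPU in the default
float32 and paying for a separate `.to()` copy-and-cast afterwards. Because
`latent_image_ids` now starts out on `device`, the `torch.arange` index
tensors added into it are created on the same device as well; adding CPU
`arange` tensors to a CUDA tensor would otherwise raise a device-mismatch
RuntimeError.
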
---
 src/diffusers/pipelines/flux/pipeline_flux.py               | 10 +++++-----
 src/diffusers/pipelines/flux/pipeline_flux_controlnet.py    |  8 ++++----
 .../flux/pipeline_flux_controlnet_image_to_image.py         | 10 +++++-----
 .../pipelines/flux/pipeline_flux_controlnet_inpainting.py   | 10 +++++-----
 src/diffusers/pipelines/flux/pipeline_flux_img2img.py       | 10 +++++-----
 src/diffusers/pipelines/flux/pipeline_flux_inpaint.py       | 10 +++++-----
 6 files changed, 29 insertions(+), 29 deletions(-)
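
Note for reviewers: a minimal, self-contained sketch of the pattern this
patch applies, using toy shapes rather than the pipeline code:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    # Old pattern: allocate in float32 on the CPU, then copy + cast.
    text_ids = torch.zeros(512, 3).to(device=device, dtype=dtype)

    # New pattern: allocate once, directly on the target device/dtype.
    text_ids = torch.zeros(512, 3, dtype=dtype, device=device)

    # Pitfall the patch accounts for: once the destination tensor lives
    # on `device`, every tensor added into it must live there too, so the
    # `torch.arange` calls gain a `device=device` argument as well.
    ids = torch.zeros(16, 16, 3, device=device, dtype=dtype)
    ids[..., 1] = ids[..., 1] + torch.arange(16, device=device)[:, None]
    ids[..., 2] = ids[..., 2] + torch.arange(16, device=device)[None, :]

On a CPU-only machine both patterns behave identically; the saving shows up
when `device` is a GPU, where the old pattern allocates host memory and then
issues an extra host-to-device copy per call.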

diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 040d935f1b88..ab4e0fc4d255 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -371,7 +371,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -427,7 +427,7 @@ def check_inputs(
 
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
 
@@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 9f33e26013d5..ceb095da0501 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -452,7 +452,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
 
@@ -462,7 +462,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 810c970ab715..be0d8f85cd85 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -407,7 +407,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -495,7 +495,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
 
@@ -505,7 +505,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 1f5f83561f1c..2a1fbce5e81e 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -417,7 +417,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -522,7 +522,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
 
@@ -532,7 +532,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index 47f9f268ee9d..aa1a3e7fc3a4 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -391,7 +391,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -479,7 +479,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
 
@@ -489,7 +489,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index 766f9864839e..97824258b28f 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -395,7 +395,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
@@ -500,7 +500,7 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
 
@@ -510,7 +510,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
-        return latent_image_ids.to(device=device, dtype=dtype)
+        return latent_image_ids
 
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents