From 6f4ada9b5cdf3e805440d1a09137b63036cbbca5 Mon Sep 17 00:00:00 2001
From: Conan <conan.nan.ke@qq.com>
Date: Fri, 14 Jun 2024 20:32:38 +0800
Subject: [PATCH 1/3] fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

---
 .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 8e951db68cc8..25e15e2e5573 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -216,7 +216,7 @@ def _get_t5_prompt_embeds(
 
         if self.text_encoder_3 is None:
             return torch.zeros(
-                (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
+                (batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
                 device=device,
                 dtype=dtype,
             )

From 3c1f545ad3090e99b50d8b518e2dd87c0e7aa581 Mon Sep 17 00:00:00 2001
From: yiyixuxu <yixu310@gmail,com>
Date: Tue, 18 Jun 2024 21:55:22 +0000
Subject: [PATCH 2/3] style

---
 .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 034254d7add7..b8fd0b907684 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -217,7 +217,11 @@ def _get_t5_prompt_embeds(
 
         if self.text_encoder_3 is None:
             return torch.zeros(
-                (batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
+                (
+                    batch_size * num_images_per_prompt,
+                    self.tokenizer_max_length,
+                    self.transformer.config.joint_attention_dim,
+                ),
                 device=device,
                 dtype=dtype,
             )

From 7238b71b91afe5bde87216de0f919a90936d54a5 Mon Sep 17 00:00:00 2001
From: yiyixuxu <yixu310@gmail,com>
Date: Tue, 18 Jun 2024 22:10:53 +0000
Subject: [PATCH 3/3] fix copies

---
 .../pipeline_stable_diffusion_3_img2img.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 44210ae8453f..fda363fc2978 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -232,7 +232,11 @@ def _get_t5_prompt_embeds(
 
         if self.text_encoder_3 is None:
             return torch.zeros(
-                (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
+                (
+                    batch_size * num_images_per_prompt,
+                    self.tokenizer_max_length,
+                    self.transformer.config.joint_attention_dim,
+                ),
                 device=device,
                 dtype=dtype,
             )