[SD3] Fix mis-matched shape when num_images_per_prompt > 1 using without T5 (text_encoder_3=None) #8558

Merged: 4 commits merged into huggingface:main on Jun 18, 2024

Conversation

Dalanke
Contributor

@Dalanke Dalanke commented Jun 14, 2024

What does this PR do?

This PR fixes a bug that occurs when num_images_per_prompt is set to any value greater than 1 in a call to StableDiffusion3Pipeline. An exception is raised when the pipeline is created without the T5 text encoder (text_encoder_3=None).

Reproduction (follow the documentation here):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_single_file(
    './stable-diffusion-3-medium/sd3_medium_incl_clips.safetensors',
    torch_dtype=torch.float16,
    text_encoder_3=None
    )
pipe = pipe.to("cuda")

image = pipe(
    "a picture of a cat holding a sign that says hello world",
    negative_prompt="",
    # any value greater than 1 triggers the error
    num_images_per_prompt=4,
    num_inference_steps=28,
    guidance_scale=7.0,
).images

for i, img in enumerate(image):
    # Pass the path directly; a text-mode file handle ('w+') cannot receive image bytes
    img.save(f'./output/test_{i}.jpg')

Bug output

Loading pipeline components...:  62%|███████████████████████████████████████▍                       | 5/8 [00:00<00:00, 16.23it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  5.47it/s]
Traceback (most recent call last):
  File "/home/xxx/workspace/sd3/sd3_inference.py", line 16, in <module>
    image = pipe(
  File "/home/xxx/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/xxx/workspace/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 778, in __call__
    ) = self.encode_prompt(
  File "/home/xxx/workspace/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 413, in encode_prompt
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 1 for tensor number 1 in the list.

Mitigation:
The bug is caused by a shape mismatch. For example, with num_images_per_prompt=4, line 413:

prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

concatenates clip_prompt_embeds of shape torch.Size([4, 77, 4096]) with t5_prompt_embed of shape torch.Size([1, 77, 4096]), so the batch dimensions (4 vs. 1) do not match.
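The failure can be reproduced in isolation with plain tensors. This is a minimal sketch that assumes only the shapes reported above; the tensors are zero-filled stand-ins rather than real embeddings, and the `cat_failed` flag is introduced here purely for illustration.

```python
import torch

num_images_per_prompt = 4

# Zero-filled stand-ins; only the shapes matter for this illustration.
clip_prompt_embeds = torch.zeros(num_images_per_prompt, 77, 4096)
t5_prompt_embed = torch.zeros(1, 77, 4096)  # batch dimension was not scaled

# torch.cat requires every dimension except the concatenation axis to match.
# Here dim=-2 is the sequence axis, so the batch sizes (4 vs. 1) must agree.
try:
    torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
    cat_failed = False
except RuntimeError as err:
    cat_failed = True
    print(err)
```

Unlike elementwise operations, torch.cat does not broadcast a size-1 batch dimension, which is why the placeholder has to carry the full batch size itself.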

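The fix is in the value that _get_t5_prompt_embeds returns when text_encoder_3=None. The batch dimension of the zero placeholder must also be multiplied by num_images_per_prompt: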

if self.text_encoder_3 is None:
    return torch.zeros(
        # shape changed here; previously the first dimension was just batch_size
        (
            batch_size * num_images_per_prompt,
            self.tokenizer_max_length,
            self.transformer.config.joint_attention_dim,
        ),
        device=device,
        dtype=dtype,
    )
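The effect of the change can be checked outside the pipeline. The sketch below uses a hypothetical standalone helper, `zero_t5_placeholder`, which is not part of diffusers; it only mirrors the shape logic of the patched branch. The sizes 77 and 4096 are taken from the shapes reported in this PR.

```python
import torch


def zero_t5_placeholder(
    batch_size,
    num_images_per_prompt,
    max_length,
    joint_attention_dim,
    device=None,
    dtype=None,
):
    """Hypothetical helper: zero embeddings standing in for the missing T5 output."""
    return torch.zeros(
        (batch_size * num_images_per_prompt, max_length, joint_attention_dim),
        device=device,
        dtype=dtype,
    )


# One prompt, four images per prompt, as in the reproduction script.
t5_prompt_embed = zero_t5_placeholder(1, 4, 77, 4096, dtype=torch.float16)

# Stand-in for the CLIP embeddings, which already carry the scaled batch size.
clip_prompt_embeds = torch.zeros(4, 77, 4096, dtype=torch.float16)

# Batch dimensions now agree, so concatenation along the sequence axis works.
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
print(prompt_embeds.shape)
```

Scaling by batch_size * num_images_per_prompt keeps the placeholder aligned with the CLIP embeddings when several prompts are passed at once, not only in the single-prompt case.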

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Dalanke
Contributor Author

Dalanke commented Jun 18, 2024

It seems some checks did not succeed, but I could not figure out why. A one-line code change should not cause a code quality issue. Could anyone kindly look into it?

@yiyixuxu yiyixuxu merged commit 2921a20 into huggingface:main Jun 18, 2024
14 of 15 checks passed
@yiyixuxu
Collaborator

Thanks!
We have make style and make fix-copies commands that you can run to pass the quality test: https://huggingface.co/docs/diffusers/en/conceptual/contribution#how-to-open-a-pr

yiyixuxu added a commit that referenced this pull request Jun 20, 2024
…out T5 (text_encoder_3=None) (#8558)

* fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

* style

* fix copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024