
Hunyuan Video does not support batch size > 1 #10453

Closed
Nerogar opened this issue Jan 4, 2025 · 2 comments · Fixed by #10454
Labels: bug (Something isn't working)

@Nerogar (Contributor) commented Jan 4, 2025

Describe the bug

The HunyuanVideoPipeline (and, I believe, the model itself) does not support execution with a batch size > 1. There are shape mismatches in the attention calculation: setting the batch size to 2 results in the error shown in the Logs section below.

Reproduction

This example is taken directly from the model card at https://huggingface.co/hunyuanvideo-community/HunyuanVideo. The only change is the added line num_videos_per_prompt=2.

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)

# Enable memory savings
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()

output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    num_videos_per_prompt=2, # <--- This is the only line I changed
).frames[0]
export_to_video(output, "output.mp4", fps=15)

Logs

File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\pipelines\hunyuan_video\pipeline_hunyuan_video.py", line 647, in __call__
    noise_pred = self.transformer(
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\accelerate\hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_hunyuan_video.py", line 763, in forward
    hidden_states, encoder_hidden_states = block(
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_hunyuan_video.py", line 478, in forward
    attn_output, context_attn_output = self.attn(
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\attention_processor.py", line 588, in forward
    return self.processor(
  File "H:\stable-diffusion\one-trainer\venv\src\diffusers\src\diffusers\models\transformers\transformer_hunyuan_video.py", line 117, in __call__
    hidden_states = F.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (24) must match the existing size (2) at non-singleton dimension 1.  Target sizes: [2, 24, 10496, 10496].  Tensor sizes: [2, 10496, 10496]
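
For reference, the broadcast failure in the last frame can be reproduced in isolation. The sketch below is not part of the original report and uses small, illustrative shapes; it only shows how F.scaled_dot_product_attention broadcasts a 3-D attn_mask against the expected [batch, heads, query_len, key_len] shape, which is why batch size 1 happens to work while batch size 2 does not.

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 24, 16, 8
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# A [batch, seq_len, seq_len] mask must broadcast to [batch, heads, seq_len, seq_len].
# With batch == 1 the leading singleton dimension expands over the heads, but with
# batch == 2 the mask's dimension 1 collides with the head dimension (24), as above.
mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)

try:
    F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
except RuntimeError as e:
    print(e)  # a shape/broadcast error; exact wording may differ by backend

# Giving the mask an explicit singleton head dimension makes it broadcastable again.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask.unsqueeze(1))
print(out.shape)  # torch.Size([2, 24, 16, 8])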


System Info

- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Windows-10-10.0.22631-SP0
- Running on Google Colab?: No
- Python version: 3.10.8
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.26.2
- Transformers version: 4.47.0
- Accelerate version: 1.0.1
- PEFT version: not installed
- Bitsandbytes version: 0.44.1
- Safetensors version: 0.4.5
- xFormers version: 0.0.28.post3
- Accelerator: NVIDIA RTX A5000, 24564 MiB
- Using GPU in script?: NVIDIA RTX A5000
- Using distributed or parallel set-up in script?: no


Who can help?

_No response_
Nerogar added the bug (Something isn't working) label on Jan 4, 2025
a-r-r-o-w self-assigned this on Jan 4, 2025
@a-r-r-o-w (Member) commented:

Sorry, this was untested when the PR was merged. Will take a look.

@Nerogar (Contributor, Author) commented Jan 4, 2025

Changing the attention mask shape in the attention processor (transformer_hunyuan_video.py, line 117 in the traceback above) like this seems to fix the issue.

Before:

        # 5. Attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

After:

        # 5. Attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask.unsqueeze(1), dropout_p=0.0, is_causal=False
        )

Not sure if this is the best approach, or if the attention mask shape should already be changed before calling the attention processor.
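
A sketch of that alternative (the helper name below is hypothetical, not taken from diffusers) would be to add the singleton head dimension before the mask reaches the attention processor, so the processor can hand it to F.scaled_dot_product_attention unchanged:

import torch

def prepare_attention_mask(attention_mask):
    # Hypothetical helper: expand a [batch, seq_len, seq_len] mask to
    # [batch, 1, seq_len, seq_len]; the singleton head dimension then broadcasts
    # against any number of attention heads inside F.scaled_dot_product_attention.
    if attention_mask is not None and attention_mask.ndim == 3:
        attention_mask = attention_mask.unsqueeze(1)
    return attention_mask

mask = torch.ones(2, 128, 128, dtype=torch.bool)
print(prepare_attention_mask(mask).shape)  # torch.Size([2, 1, 128, 128])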
