
CUDAGRAPHs for Flux position embeddings #9299

Closed
sayakpaul opened this issue Aug 28, 2024 · 0 comments · Fixed by #9307
sayakpaul commented Aug 28, 2024

@yiyixuxu

Is it possible to refactor the Flux positional embeddings so that we can fully make use of CUDAGRAPHs?

skipping cudagraphs due to cpu device (device_put). Found from :
   File "/home/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 469, in forward
    image_rotary_emb = self.pos_embed(ids)
  File "/home/sayak/.pyenv/versions/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/sayak/diffusers/src/diffusers/models/embeddings.py", line 630, in forward
    self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
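
As a side note, skip messages like the one above can be surfaced while iterating by enabling PyTorch's logging artifacts; the snippet below assumes perf_hints is the artifact under which inductor reports cudagraph skip reasons:

import torch

# Print inductor performance hints, including the
# "skipping cudagraphs due to ..." messages quoted above.
torch._logging.set_logs(perf_hints=True)

Equivalently, run the repro with the TORCH_LOGS="perf_hints" environment variable set.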
Code
import torch

torch.set_float32_matmul_precision("high")
# These flags live on the inductor config module; setting them on
# torch._inductor directly is a silent no-op.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())


pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# Run a few times so compilation/autotuning warm-up is amortized.
for _ in range(5):
    image = pipe(
        "Happy bear",
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    ).images[0]

(mode="max-autotune" already opts the compiled modules into CUDA Graphs where possible.) If we can fully make use of CUDAGRAPHs, torch.compile() would be faster.
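
For illustration, here is one possible shape for such a refactor. This is a hypothetical sketch, not the actual FluxPosEmbed implementation; rope, GPUPosEmbed, and their signatures are made up for this example. The idea is simply to build the rotary cos/sin tables directly on ids.device so nothing inside the compiled region round-trips through the CPU:

import torch
import torch.nn as nn


def rope(pos: torch.Tensor, dim: int, theta: float):
    # Build the frequency table on pos.device -- no CPU tensors involved.
    scale = torch.arange(0, dim, 2, device=pos.device, dtype=torch.float32) / dim
    omega = 1.0 / (theta**scale)
    out = pos.unsqueeze(-1).float() * omega  # (..., dim // 2)
    return torch.cos(out), torch.sin(out)


class GPUPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor):
        # One rotary table per axis, all computed on ids.device.
        cos_out, sin_out = zip(
            *(rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(ids.shape[-1]))
        )
        return torch.cat(cos_out, dim=-1), torch.cat(sin_out, dim=-1)

With something along these lines, image_rotary_emb would be produced entirely on the GPU, and the device_put skip should disappear.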
