
refactor rotary embedding 3: so it is not on cpu #9307

Merged 2 commits into main from fix-torch-rope on Aug 29, 2024

Conversation

@yiyixuxu (Collaborator) commented Aug 28, 2024

fix #9299

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu changed the title refactor rotary embedding so it is not on cpu → refactor rotary embedding 3: so it is not on cpu Aug 29, 2024
@yiyixuxu requested a review from sayakpaul August 29, 2024 19:08
@sayakpaul (Member) left a comment

Thanks!

@yiyixuxu would it be possible to trigger a torch.compile() run with PyTorch nightly to verify that this helps with the CUDAGraph issue? Code is in #9299 (comment).
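For reference, a minimal sketch of such a check; the authoritative repro is in #9299 (comment), and the model id and compile settings here simply mirror the benchmark script later in this thread:

# Minimal sketch (assumes PyTorch nightly and access to the FLUX.1-dev weights).
# mode="max-autotune" enables CUDAGraphs in inductor, so a CPU-resident tensor
# inside the transformer forward would surface as cudagraph skips/re-records.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
_ = pipe("a photo of a cat", num_inference_steps=4)  # first call triggers compile + capture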

Cc'ing @cpuhrsch, maybe you would like to review it?

@yiyixuxu (Collaborator, Author)

@sayakpaul yeah, I tested it and it's fine.

@sayakpaul merged commit 61d96c3 into main Aug 29, 2024
18 checks passed
@sayakpaul deleted the fix-torch-rope branch August 29, 2024 19:37
@yiyixuxu (Collaborator, Author)

@sayakpaul is this a reasonable script? I want to compare the performance against 0.30.1-patch before we introduce the rotary embedding refactor.

import torch
import torch.utils.benchmark as benchmark
import gc

torch.set_float32_matmul_precision("high")
# Inductor tuning flags live under torch._inductor.config.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

import diffusers
from platform import python_version
from diffusers import DiffusionPipeline

print(diffusers.__version__)
print(torch.__version__)
print(python_version())

def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"

def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"

def flush():
    """Wipes off memory."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt_embeds = torch.load("flux_prompt_embeds.pt")
pooled_prompt_embeds = torch.load("flux_pooled_prompt_embeds.pt")

def run_inference(pipe):
    _ = pipe(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=5,
        guidance_scale=3.5,
        max_sequence_length=512,
        generator=torch.manual_seed(42),
        height=1024,
        width=1024,
    )

flush()

time = benchmark_fn(run_inference, pipe)
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GiB
print(f"Execution time: {time} sec")
print(f"Memory: {memory} GiB")

Review comment on the following lines of get_1d_rotary_pos_embed:

theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
freqs = freqs.to(pos.device)


I'd expect this to cause a sync as well, since by default arange allocates on the CPU. One way to mitigate could be to:
a) use pin_memory() on freqs ahead of time and set non_blocking=True, or
b) do the arange on the GPU right away (i.e. torch.arange([...], device=pos.device)).
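As a sketch, option (b) applied to the quoted line could look like this; it assumes pos is already a torch tensor on the target device (in the hunk above it is still numpy), so it is illustrative rather than the exact patch that landed:

# Option (b) sketch: build the arange directly on pos's device so no CPU
# tensor (and no host-device sync) enters the graph. Assumes pos is already
# a torch.Tensor on the GPU, which is not yet true in the hunk above.
freqs = 1.0 / (
    theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)
) / linear_factor  # [D/2]
freqs = torch.outer(pos, freqs)  # [S, D/2]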

@yiyixuxu (Collaborator, Author)

ohhh let's do torch.arange([...], device=pos.device)

@sayakpaul (Member)

@yiyixuxu that looks reasonable, but I'd call run_inference() maybe 2 or 3 times first for warmup.
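A sketch of that warmup, reusing the names from the script above (the placement before flush() and the timed run is an assumption):

# Warm up before timing: the first few calls pay for torch.compile's
# compilation, autotuning, and CUDAGraph capture, which should not be
# counted in the benchmark.
for _ in range(3):
    run_inference(pipe)

flush()
time = benchmark_fn(run_inference, pipe)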

@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed(
     assert dim % 2 == 0
 
     if isinstance(pos, int):
-        pos = np.arange(pos)
+        pos = torch.arange(pos)


This should also be passed a device argument to allocate it on the GPU. If this isn't on the GPU, the tensors derived from it won't be either.
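A hedged sketch of that suggestion; the device parameter is illustrative, not necessarily the signature that landed:

# Sketch: allocate pos on the requested device up front so every tensor
# derived from it (freqs, the outer product, cos/sin) stays on the GPU.
# `device` is an assumed parameter here, shown for illustration only.
if isinstance(pos, int):
    pos = torch.arange(pos, device=device)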

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
change get_1d_rotary to accept pos as torch tensors
Successfully merging this pull request may close these issues.

CUDAGRAPHs for Flux position embeddings