Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[torch.compile] rework compile control with piecewise cudagraph #9715

Merged
merged 83 commits into from
Oct 30, 2024

Conversation

youkaichao
Copy link
Member

@youkaichao youkaichao commented Oct 26, 2024

rework the compilation control.

the user-facing flags are:

Usage CompilationLevel how vLLM uses Dynamo how vLLM uses Inductor use vLLM's custom ops (*) how to customize the compilation
export VLLM_TORCH_COMPILE_LEVEL=0 (default) NO_COMPILATION (0) N/A N/A N/A
export VLLM_TORCH_COMPILE_LEVEL=1 DYNAMO_AS_IS (1) use as-is N/A vllm.plugins.set_torch_compile_backend (default to "eager")
export VLLM_TORCH_COMPILE_LEVEL=2 DYNAMO_ONCE (2) use only once, make sure computation graph does not change N/A vllm.plugins.set_torch_compile_backend (default to "eager")
export VLLM_TORCH_COMPILE_LEVEL=3 PIECEWISE (3) same as 2 compilation behavior determined by the config write a json config file and specify it with VLLM_TORCH_COMPILE_CONFIG, or call vllm.plugins.set_compilation_config to set the config directly

* : users can also use VLLM_CUSTOM_OPS env var to have fine-grained control over custom ops.

For the detailed compilation config, please check the code doc.

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Copy link
Member Author

youkaichao commented Oct 26, 2024

example code:

import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.compile_context import set_compile_context
from vllm.plugins import set_attention_ops
set_attention_ops(["silly.attention"])

import torch
from torch import nn

@torch.library.custom_op("silly::attention", mutates_args=["out"])
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None:
    print("silly")
    out.copy_(q)
    print(q)
    out[0] += 1

@silly_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None:
    return

@support_torch_compile
class SillyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x

model = SillyModel()

input_buffer = torch.randn(100).cuda()

with set_compile_context([1, 2]):
    model(input_buffer)

    model(input_buffer[:2])
    model(input_buffer[:1])

input_buffer[:2].zero_()
output = model(input_buffer[:2])
print(output.__class__)
print(output[:2])

run with:

VLLM_LOGGING_LEVEL=DEBUG VLLM_TORCH_COMPILE_LEVEL=3 python test.py

requirements:

  • attention ops will be the boundary of piecewise graphs
  • if the output of attention ops is used in the subsequent graph, then it needs to be allocated in the previous graph, and passed to attention ops for mutation.

# FIXME: it seems pytorch changes the output to a tuple
it can be fixed by pytorch/pytorch#138980

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
@youkaichao
Copy link
Member Author

I added some tests based on counters, following pytorch's test principle. @ProExpertProg

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
@youkaichao youkaichao added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 29, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
@youkaichao
Copy link
Member Author

test failure is unrelated and also appear in main branch, merging

@youkaichao youkaichao merged commit ff5ed6e into vllm-project:main Oct 30, 2024
60 of 68 checks passed
@youkaichao youkaichao deleted the piece_wise branch October 30, 2024 06:03
rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Linkun Chen <github+anyscale@lkchen.net>
hissu-hyvarinen pushed a commit to ROCm/vllm that referenced this pull request Nov 6, 2024
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Loc Huynh <jc1da.3011@gmail.com>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
…-project#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
ci/build ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants