
Add option to use torch._inductor.standalone_compile #17057


Merged: 1 commit merged into vllm-project:main on May 9, 2025

Conversation


@zou3519 (Collaborator) commented Apr 23, 2025

This PR adds the option to use torch._inductor.standalone_compile to perform compilation instead of compile_fx. The goal of standalone_compile is to remove the hacks around vLLM's usage of compile_fx; we want to migrate to it in PyTorch 2.8.

standalone_compile replaces how vLLM interacts with the torch.compile caches. Instead of vLLM trying to redirect those caches into its torch_compile_cache folder, vLLM can pass standalone_compile a filepath inside the torch_compile_cache folder, and standalone_compile will write the full precompiled artifact to it.
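
A minimal sketch of the intended flow, assuming the PyTorch >= 2.8 standalone_compile API (exact signatures and formats may differ, and artifact_path is a hypothetical name):

```python
import torch
from torch._inductor import standalone_compile

def compile_and_persist(gm, example_inputs, artifact_path):
    # Compile the FX graph with Inductor, independently of
    # torch.compile's internal caches.
    compiled = standalone_compile(gm, example_inputs)
    # Write the full precompiled artifact to the caller-chosen path,
    # e.g. a file inside vLLM's torch_compile_cache folder.
    compiled.save(path=artifact_path, format="unpacked")
    return compiled

def load_persisted(artifact_path):
    # On a warm start, load the saved artifact instead of recompiling.
    return torch._inductor.CompiledArtifact.load(
        path=artifact_path, format="unpacked")
```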

Right now this option is hidden behind an environment variable (VLLM_TEST_STANDALONE_COMPILE). It is also not tested in vLLM CI (vLLM CI only tests against PyTorch 2.6). This option also needs more testing before we turn it on by default for PyTorch 2.8+. I am putting this PR out so that we can merge something that we can keep developing on top of.

Test Plan:

- Run https://gist.github.com/zou3519/aebb622714e80f4cd4c369472f2372cd with and without VLLM_TEST_STANDALONE_COMPILE


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@zou3519 force-pushed the standalone_compile branch 3 times, most recently from 031b8a1 to 24dc355 on April 23, 2025 18:32
@zou3519 marked this pull request as ready for review April 24, 2025 00:42
@youkaichao self-assigned this Apr 25, 2025
@zou3519 force-pushed the standalone_compile branch from 24dc355 to 0bbb2e9 on April 30, 2025 03:52

mergify bot commented May 1, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zou3519.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label May 1, 2025
@zou3519 force-pushed the standalone_compile branch from 0bbb2e9 to 5bb2380 on May 2, 2025 14:58
@mergify bot removed the needs-rebase label May 2, 2025
@houseroad (Collaborator) left a comment:

Overall looks reasonable to me.

@zou3519 force-pushed the standalone_compile branch from 5bb2380 to 80d85bd on May 5, 2025 12:45
@zou3519 requested a review from houseroad May 6, 2025 01:56

mergify bot commented May 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zou3519.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label May 6, 2025

@houseroad (Collaborator):

Let's rebase?

@zou3519 force-pushed the standalone_compile branch from 80d85bd to dcdbdc3 on May 6, 2025 19:25
@mergify bot removed the needs-rebase label May 6, 2025
@zou3519 force-pushed the standalone_compile branch 2 times, most recently from 8a9288d to a147b3f on May 7, 2025 18:45
Comment on lines +216 to +220
# TODO(rzou): the implication is that we're not
# reading the python bytecode correctly in vLLM?
Collaborator:

@zou3519 Could you explain this comment? I'd like to understand the sketchiness.

Collaborator (author):

This is a pre-existing problem; the existing InductorAdaptor also has this code in it.

If a function being compiled returns a single tensor, e.g. f(x) = x.sin(), then Inductor is always passed a graph that returns a (Tensor,), and it returns a compiled artifact that likewise returns a (Tensor,). Dynamo is responsible for unpacking this back into a single tensor via the bytecode it generates.

vLLM takes the bytecode that Dynamo generates and turns it into some Python code that wraps the compiled artifact. However, since we also need to manually do the unpacking here, I suspect that vLLM is not doing that process correctly.
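
A small illustration of the convention being described (not vLLM's actual code; the wrapper name is hypothetical):

```python
import torch

def f(x):
    return x.sin()  # user code returns a single Tensor

# Inductor is always handed a graph that returns a tuple, so the
# compiled artifact behaves like: compiled_fn(x) -> (torch.sin(x),).
# In torch.compile, Dynamo-generated bytecode unpacks that 1-tuple
# back into a single Tensor. A wrapper that reproduces the unpacking
# by hand looks roughly like:
def unwrap_single_output(compiled_fn):
    def runner(*args):
        out = compiled_fn(*args)  # always a tuple, e.g. (Tensor,)
        return out[0]             # restore the original return shape
    return runner
```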

Collaborator:

Makes sense, thanks for the explanation.

Member:

> since we also need to manually do the unpacking here, I suspect that vLLM is not doing that process correctly

What does this mean? As you mentioned, we have special handling logic for the case where the original graph returns a single tensor, and I think vLLM is correct here.

Collaborator (author):

> What does this mean? As you mentioned, we have special handling logic for the case where the original graph returns a single tensor, and I think vLLM is correct here.

In torch.compile, the handling logic for what happens when the original graph returns a single tensor is in the Dynamo-produced bytecode. In vLLM, the handling logic is in the InductorAdaptor. I would expect it to be in the Dynamo-produced bytecode.

Collaborator (author):

@youkaichao mentioned to me that vLLM does use the Dynamo-produced bytecode directly so... this needs more investigation

Collaborator:

Let's track it in a follow-up PR (e.g., by creating an issue)?

@zou3519 force-pushed the standalone_compile branch from a147b3f to 4cd5ae6 on May 7, 2025 19:08
vllm/config.py Outdated
@@ -3639,6 +3642,7 @@ class CompilationConfig(BaseModel):
     compile_sizes: Optional[list[Union[int, str]]] = Field(default=None)
     inductor_compile_config: dict = Field(default_factory=dict)
     inductor_passes: dict[str, str] = Field(default_factory=dict)
+    use_standalone_compile: bool = False
Member:

This is a user-facing config, and I don't think we should add it here. If you want to switch the behavior, you can use an env var like VLLM_TEST_XXX; that's less user-facing and shows it is not intended for external use cases.

Collaborator (author):

Thanks for pointing that out; I'll update this.

Collaborator (author):

Updated
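
For reference, a hedged sketch of the env-var gating pattern that was adopted (the merged PR uses VLLM_TEST_STANDALONE_COMPILE per the description; the exact parsing in vLLM's envs module may differ):

```python
import os

# Test-only switch read from the environment rather than exposed as a
# user-facing CompilationConfig field; the "0"/"1" parsing here is an
# assumption.
VLLM_TEST_STANDALONE_COMPILE = bool(
    int(os.getenv("VLLM_TEST_STANDALONE_COMPILE", "0")))
```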

@@ -123,8 +134,15 @@ def compile(self,

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if self.compilation_config.use_standalone_compile:
            maybe_key = \
Member:

I think it should be the compiler's responsibility to summarize the graph and generate the key. The compiler manager just calls the compiler. It should not provide the graph index.

Collaborator (author):

@youkaichao The "key" is the file that the compiled artifact gets saved to. I don't think it's the compiler's responsibility to specify the key. Consider gcc - by default it will write a file to a.out, otherwise it's the user's responsibility to pick a name for their binary.

From the Inductor side, the inductor hash key computation is not public API and shouldn't be.
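
The gcc analogy in code form (purely illustrative; every name here is hypothetical):

```python
import os

# gcc defaults to a.out; the caller overrides it with `-o build/main`.
# Likewise, the compiler manager (the caller) picks where the artifact
# lands, and the compiler just writes to the given path.
def artifact_path_for(cache_dir: str, graph_index: int) -> str:
    return os.path.join(cache_dir, f"inductor_artifact_{graph_index}")
```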


Collaborator:

Btw, do we need to use the key for hash computation to validate the cache?

Collaborator (author):

> Btw, do we need to use the key for hash computation to validate the cache?

We do not use this key to validate the vLLM cache. The "key" here is just a file path to the compiled artifact.

The "key" that is returned by InductorAdapator has semantic meaning to torch.compile. But vLLM manages its own cache and compilation, so this key has no semantic meaning to vLLM.

@@ -127,9 +132,99 @@ def produce_guards_expression(self, *args, **kwargs):
        return ""


class InductorStandaloneAdaptor(CompilerInterface):
Collaborator:

We should add a unit test for this adaptor. Feel free to do it in a follow-up PR.

Collaborator (author):

In the follow-up PR I will turn on InductorStandaloneAdaptor for PyTorch >= 2.8, so it gets tested in the "torch nightly" vLLM CI.

@zou3519 force-pushed the standalone_compile branch from 7bafaae to 009857e on May 8, 2025 22:33
"""
name = "inductor_standalone"

def compute_hash(self, vllm_config: VllmConfig) -> str:
Collaborator:

Do we need to consider the non-PyTorch source code in the hash?

Collaborator (author):

What do you mean? The compute_hash function here is the same as the compute_hash function in InductorAdaptor. I can put this into a helper function for better code reuse.

Collaborator:

Synced offline: it will be called by vLLM to compute the overall cache key.
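
For illustration, a hedged sketch of the shared helper mentioned above (names and the exact hash factors are assumptions; the real compute_hash lives in vLLM's compiler interface):

```python
import hashlib
from typing import Any

import torch

def inductor_compile_hash(vllm_config) -> str:
    """Hypothetical helper that both InductorAdaptor and
    InductorStandaloneAdaptor could call, keeping hash factors in sync."""
    factors: list[Any] = []
    # The PyTorch version affects Inductor codegen, so include it.
    factors.append(torch.__version__)
    # Inductor config knobs that vLLM sets also affect the output.
    factors.append(vllm_config.compilation_config.inductor_compile_config)
    return hashlib.sha256(repr(factors).encode()).hexdigest()[:10]
```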

@houseroad added the ready label (ONLY add when PR is ready to merge/full CI is needed) May 8, 2025
@houseroad (Collaborator) left a comment:

Looks good to me.

The commit message:

This PR adds the option to use torch._inductor.standalone_compile to
perform compilation instead of compile_fx. The goal of
standalone_compile is to remove the hacks around vLLM's usage of
compile_fx; we want to migrate to it in PyTorch 2.8.

standalone_compile replaces how vLLM interacts with the torch.compile
caches. Instead of vLLM trying to redirect them into its
torch_compile_cache folder, vLLM can pass standalone_compile a filepath
that is inside of the torch_compile_cache folder and standalone_compile
will write the full precompiled artifact to it.

Right now this option is hidden behind an envvar. It is also not
tested in vLLM CI (vLLM CI only tests against PyTorch 2.6).
This option also needs more testing before we turn it on by default for
PyTorch 2.8+. I am putting this PR out so that we can
merge something that we can keep developing on top of.

Test Plan:
- Run https://gist.github.com/zou3519/aebb622714e80f4cd4c369472f2372cd
with or without VLLM_TEST_STANDALONE_COMPILE

Signed-off-by: rzou <zou3519@gmail.com>
@zou3519 force-pushed the standalone_compile branch from 009857e to e651f13 on May 8, 2025 23:03
@houseroad merged commit ea2236b into vllm-project:main May 9, 2025
50 checks passed
princepride pushed a commit to princepride/vllm that referenced this pull request May 10, 2025

Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>