Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[torch.compile] support moe models #9632

Merged
merged 7 commits into from
Oct 28, 2024

Conversation

youkaichao
Copy link
Member

moe models will read config file to determine the triton config to run.

reading files during forward is a disaster for torch.compile .

this pr wraps the config reading part inside a custom op, so that it can pass torch.compile (although torch.compile will not be able to optimize it).

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it does not work with FLASHINFER due to some head size config.

gating_output=router_logits,
renormalize=renormalize)

forward_native = forward_cuda
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when we use inductor, we will use forward_native . this function needs to be implemented.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can also write a pytorch native implementation, but I doubt if inductor can optimize it. that's why I use forward_cuda as forward_native

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically, forward_native is the function we compile when we use inductor.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I was just surprised because I didn't know that a class method can be defined this way

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it's essentially the same as

def forward_native(self, *args, **kwargs):
    return self.forward_cuda(*args, **kwargs)

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.

@youkaichao youkaichao added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 27, 2024
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Signed-off-by: youkaichao <youkaichao@gmail.com>
@youkaichao youkaichao merged commit 32176fe into vllm-project:main Oct 28, 2024
68 checks passed
@youkaichao youkaichao deleted the compile_moe branch October 28, 2024 04:58
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: qishuai <ferdinandzhong@gmail.com>
rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Linkun Chen <github+anyscale@lkchen.net>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: youkaichao <youkaichao@gmail.com>
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants