Commit a36e070

[torch.compile] fix functionalization (#8480)
1 parent 8a0cf1d commit a36e070

3 files changed: +167 −5 lines changed

tests/compile/test_full_graph.py

+9 −4

@@ -16,7 +16,12 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
-              enforce_eager=True,
-              load_format="dummy")
-    llm.generate(prompts, sampling_params)
+    llm = LLM(model=model, enforce_eager=True)
+
+    outputs = llm.generate(prompts, sampling_params)
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

vllm/compilation/backends.py

+156
@@ -0,0 +1,156 @@
+import operator
+
+import torch
+import torch.fx as fx
+
+
+def fix_functionalization(graph: fx.Graph):
+    """
+    Rewrite the graph module to replace the pattern involving
+    torch._higher_order_ops.auto_functionalize.auto_functionalized
+    with a direct call to the inplace custom op.
+
+    # TODO: check if PyTorch nightly has fixed this issue
+    """
+
+    # debug code, if we want to see the graph before the transformation
+    # with open("before.py", "w") as f:
+    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
+
+    nodes_to_remove = []
+
+    for node in graph.nodes:
+        # Identify the auto_functionalized node
+        if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized:  # noqa
+            if node.args[0] == torch.ops._C.rotary_embedding.default:
+                # manual replace for rotary_embedding
+
+                # Now, collect the arguments
+                kwargs = node.kwargs
+
+                query = kwargs['query']
+                mm_node = query.args[0].args[0]
+
+                # Create a new call to torch.ops._C.rotary_embedding.default
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(torch.ops._C.rotary_embedding.default,
+                                        kwargs=kwargs)
+
+                # Remove the auto_functionalized node
+                # Since the node may have outputs, we need to handle its users
+                # Replace uses of the outputs (getitem nodes) with mm_node
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        # Remove the getitem node
+                        for getitem_user in list(user.users):
+                            if (getitem_user.op == 'call_function'
+                                    and getitem_user.target
+                                    == torch.ops.aten.slice_scatter.default):
+                                # Replace the uses of slice_scatter node
+                                # with mm_node
+                                getitem_user.replace_all_uses_with(mm_node)
+                                nodes_to_remove.append(getitem_user)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+            elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
+                # manual replace for fused_add_rms_norm
+                # this is the most effective optimization for llama
+                # failing to do this will result in many unnecessary copies
+
+                kwargs = node.kwargs
+
+                input = kwargs['input']
+                residual = kwargs['residual']
+
+                # Create a new call to torch.ops._C.fused_add_rms_norm.default
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        # Remove the getitem node
+                        if user.args[1] == 1:
+                            replace_node = input
+                        elif user.args[1] == 2:
+                            replace_node = residual
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+            elif node.args[0] == torch.ops._C.rms_norm.default:
+                # manual replace for rms_norm
+
+                kwargs = node.kwargs
+
+                input = kwargs['input']
+                out = kwargs['out']
+                weight = kwargs['weight']
+                epsilon = kwargs['epsilon']
+                # Create a new call to torch.ops._C.rms_norm.default
+                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.rms_norm.default,
+                        args=(out, input, weight, epsilon),
+                    )
+
+                replace_node = out
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+            elif node.args[0] == torch.ops._C.silu_and_mul.default:
+                # manual replace for silu_and_mul
+
+                kwargs = node.kwargs
+
+                input = kwargs['input']
+                out = kwargs['out']
+
+                # Create a new call to torch.ops._C.silu_and_mul.default
+                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.silu_and_mul.default,
+                        args=(out, input),
+                    )
+                replace_node = out
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+    # Remove the nodes all at once
+    for node in nodes_to_remove:
+        graph.erase_node(node)
+
+    # debug code, if we want to see the graph after the transformation
+    # with open("after.py", "w") as f:
+    #     print(graph.python_code(root_module="self", verbose=True).src, file=f)
+
+
+def vllm_backend(graph, example_inputs):
+    from torch._inductor import config
+    current_config = config.shallow_copy_dict()
+    from torch._inductor.compile_fx import compile_fx
+    current_config['post_grad_custom_post_pass'] = fix_functionalization
+    return compile_fx(graph, example_inputs, config_patches=current_config)
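The new pass replaces each auto_functionalized wrapper node (and the operator.getitem nodes that read its outputs) with a direct call to the in-place custom op, rerouting downstream users back to the original output buffers. As a minimal, hedged illustration of how the backend can be exercised on its own, the sketch below compiles a toy module with vllm_backend; ToyModel is made up for this example and is not part of the commit, and on a module with no vLLM custom ops the post-grad pass is effectively a no-op.

import torch

from vllm.compilation.backends import vllm_backend


class ToyModel(torch.nn.Module):
    # stand-in module with no vLLM custom ops, used only to show the wiring
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x) + 1


model = ToyModel()
# vllm_backend follows the torch.compile backend protocol: it receives the
# captured fx.GraphModule plus example inputs and hands them to Inductor's
# compile_fx with fix_functionalization installed as the post-grad pass.
compiled = torch.compile(model, backend=vllm_backend)
out = compiled(torch.randn(4, 8))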

vllm/worker/model_runner.py

+2 −1

@@ -1064,8 +1064,9 @@ def load_model(self) -> None:
                            "This may lead to less accurate results!")

         if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
+            from vllm.compilation.backends import vllm_backend
             from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+            backend = get_torch_compile_backend() or vllm_backend
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
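With this change, the Dynamo path falls back to the new vllm_backend instead of "eager" whenever no backend has been registered through vllm.plugins. Below is a hedged end-to-end sketch (not part of the commit) of driving that path from user code: it assumes the environment variable checked in load_model() above must be set before the model is loaded, and "facebook/opt-125m" is only an illustrative small model.

import os

# enable the test-only Dynamo graph-capture path guarded in load_model()
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)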
