Skip to content

Commit 5dd7c9f

Browse files
youkaichao, WoosukKwon
authored and committed
[v1][torch.compile] support managing cudagraph buffer (vllm-project#10203)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent eb0bc81 commit 5dd7c9f

File tree

4 files changed

+59
-8
lines changed

4 files changed

+59
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{
22
"use_cudagraph": true,
3-
"non_cudagraph_ops": ["silly.attention"]
3+
"non_cudagraph_ops": ["silly.attention"],
4+
"cudagraph_copy_inputs": true
45
}

tests/compile/piecewise/test_simple.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
8080
config = os.path.join(directory, "piecewise_compilation_config.json")
8181
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
8282

83-
input_buffer = torch.randn(100).cuda()
83+
inputs = torch.randn(100).cuda()
8484

8585
with compilation_counter.expect(
8686
num_graphs_seen=1, # one graph for the model
@@ -92,15 +92,15 @@ def test_simple_piecewise_compile():
9292
):
9393

9494
with set_compile_context([1, 2]):
95-
model(input_buffer)
95+
model(inputs)
9696

97-
model(input_buffer[:2])
98-
model(input_buffer[:1])
97+
model(torch.randn(2).cuda())
98+
model(torch.randn(1).cuda())
9999

100-
input_buffer[:2].zero_()
100+
input = torch.zeros(2).cuda()
101101
global global_counter
102102
global_counter = 0
103-
output = model(input_buffer[:2])
103+
output = model(input)
104104
assert global_counter == 2
105105
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
106106

vllm/compilation/backends.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,8 @@ class VllmBackend:
389389
returned_callable: Callable
390390
# Inductor passes to run on the graph pre-defunctionalization
391391
post_grad_passes: Sequence[Callable]
392+
sym_tensor_indices: List[int]
393+
input_buffers: List[torch.Tensor]
392394

393395
def __init__(self, post_grad_passes: Sequence[Callable] = ()):
394396
global global_graph_pool
@@ -401,6 +403,9 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()):
401403
self.graph_pool = global_graph_pool
402404
self.post_grad_passes = post_grad_passes
403405

406+
self.sym_tensor_indices = []
407+
self.input_buffers = []
408+
404409
# `torch.compile` is JIT compiled, so we don't need to
405410
# do anything here
406411

@@ -461,7 +466,46 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
461466

462467
self._called = True
463468

464-
return self.split_gm
469+
if not self.compilation_configs.use_cudagraph or \
470+
not self.compilation_configs.cudagraph_copy_inputs:
471+
return self.split_gm
472+
473+
# if we need to copy input buffers for cudagraph
474+
from torch._guards import detect_fake_mode
475+
fake_mode = detect_fake_mode()
476+
fake_args = [
477+
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
478+
for t in example_inputs
479+
]
480+
481+
# index of tensors that have symbolic shapes (batch size)
482+
self.sym_tensor_indices = [
483+
i for i, x in enumerate(fake_args)
484+
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
485+
]
486+
487+
# compiler managed cudagraph input buffers
488+
# we assume the first run with symbolic shapes
489+
# has the maximum size among all runs
490+
self.input_buffers = [
491+
example_inputs[x].clone() for x in self.sym_tensor_indices
492+
]
493+
494+
def copy_and_call(*args):
495+
list_args = list(args)
496+
for i, index in enumerate(self.sym_tensor_indices):
497+
runtime_tensor = list_args[index]
498+
runtime_shape = runtime_tensor.shape[0]
499+
static_tensor = self.input_buffers[i][:runtime_shape]
500+
501+
# copy the tensor to the static buffer
502+
static_tensor.copy_(runtime_tensor)
503+
504+
# replace the tensor in list_args with the static buffer
505+
list_args[index] = static_tensor
506+
return self.split_gm(*list_args)
507+
508+
return copy_and_call
465509

466510

467511
@dataclasses.dataclass

vllm/compilation/config.py

+6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class CompilationConfig(BaseModel):
3232
It means the first several runs will be treated as warmup runs.
3333
Only after that, the execution will be recorded, and the recorded
3434
cudagraph will be used for subsequent runs.
35+
- cudagraph_copy_inputs: whether to copy input tensors for
36+
cudagraph. If the caller can guarantee that the same input buffers
37+
are always used, it can set this to False. Otherwise, it should
38+
set this to True, and the compiler will copy the input to an
39+
internally managed buffer. Default is False.
3540
- Inductor compilation:
3641
- use_inductor: whether to use inductor compilation.
3742
- False: inductor compilation is not used. graph runs in eager.
@@ -78,6 +83,7 @@ class CompilationConfig(BaseModel):
7883
non_cudagraph_ops: List[str] = Field(default_factory=list)
7984
cudagraph_num_of_warmups: int = 0
8085
cudagraph_capture_sizes: Optional[List[int]] = None
86+
cudagraph_copy_inputs: bool = False
8187

8288
dump_graph_stages: List[str] = Field(default_factory=list)
8389
dump_graph_dir: Path = Field(default=Path("."))

0 commit comments

Comments
 (0)