[v1][torch.compile] support managing cudagraph buffer #10203

Merged · 6 commits · Nov 11, 2024
tests/compile/piecewise/piecewise_compilation_config.json (3 changes: 2 additions & 1 deletion)

@@ -1,4 +1,5 @@
 {
     "use_cudagraph": true,
-    "non_cudagraph_ops": ["silly.attention"]
+    "non_cudagraph_ops": ["silly.attention"],
+    "cudagraph_copy_inputs": true
 }
tests/compile/piecewise/test_simple.py (12 changes: 6 additions & 6 deletions)

@@ -80,7 +80,7 @@ def test_simple_piecewise_compile():
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
 
-    input_buffer = torch.randn(100).cuda()
+    inputs = torch.randn(100).cuda()
 
     with compilation_counter.expect(
             num_graphs_seen=1,  # one graph for the model
@@ -92,15 +92,15 @@ def test_simple_piecewise_compile():
     ):
 
         with set_compile_context([1, 2]):
-            model(input_buffer)
+            model(inputs)
 
-            model(input_buffer[:2])
-            model(input_buffer[:1])
+            model(torch.randn(2).cuda())
+            model(torch.randn(1).cuda())
 
-            input_buffer[:2].zero_()
+            input = torch.zeros(2).cuda()
             global global_counter
             global_counter = 0
-            output = model(input_buffer[:2])
+            output = model(input)
             assert global_counter == 2
             assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
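The rewritten test passes freshly allocated tensors on every call, which only works once the backend copies them into its own static buffers. As background, here is a minimal sketch of the constraint this PR addresses, using plain PyTorch APIs rather than code from this PR: a replayed CUDA graph reads the memory addresses recorded at capture time, so a new tensor argument is invisible unless its contents are copied into the captured buffer.

import torch

# Static buffers whose addresses the graph will capture.
static_in = torch.zeros(2, device="cuda")
static_out = torch.zeros(2, device="cuda")

# Warm up on a side stream before capture, per the PyTorch docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in + 1)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in + 1)

fresh = torch.tensor([3.0, 1.0], device="cuda")
g.replay()              # still reads the captured static_in, not `fresh`
static_in.copy_(fresh)  # copy new data into the captured buffer...
g.replay()              # ...now the replay sees the new values
torch.cuda.synchronize()
assert torch.allclose(static_out.cpu(), torch.tensor([4.0, 2.0]))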
vllm/compilation/backends.py (46 changes: 45 additions & 1 deletion)

@@ -389,6 +389,8 @@ class VllmBackend:
     returned_callable: Callable
     # Inductor passes to run on the graph pre-defunctionalization
     post_grad_passes: Sequence[Callable]
+    sym_tensor_indices: List[int]
+    input_buffers: List[torch.Tensor]
 
     def __init__(self, post_grad_passes: Sequence[Callable] = ()):
         global global_graph_pool
@@ -401,6 +403,9 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()):
         self.graph_pool = global_graph_pool
         self.post_grad_passes = post_grad_passes
 
+        self.sym_tensor_indices = []
+        self.input_buffers = []
+
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
@@ -461,7 +466,46 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
         self._called = True
 
-        return self.split_gm
+        if not self.compilation_configs.use_cudagraph or \
+                not self.compilation_configs.cudagraph_copy_inputs:
+            return self.split_gm
+
+        # if we need to copy input buffers for cudagraph
+        from torch._guards import detect_fake_mode
+        fake_mode = detect_fake_mode()
+        fake_args = [
+            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
+            for t in example_inputs
+        ]
+
+        # indices of tensors that have symbolic shapes (batch size)
+        self.sym_tensor_indices = [
+            i for i, x in enumerate(fake_args)
+            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
+        ]
+
+        # compiler-managed cudagraph input buffers; we assume the first
+        # run with symbolic shapes has the maximum size among all tensors
+        self.input_buffers = [
+            example_inputs[x].clone() for x in self.sym_tensor_indices
+        ]
+
+        def copy_and_call(*args):
+            list_args = list(args)
+            for i, index in enumerate(self.sym_tensor_indices):
+                runtime_tensor = list_args[index]
+                runtime_shape = runtime_tensor.shape[0]
+                static_tensor = self.input_buffers[i][:runtime_shape]
+
+                # copy the runtime tensor into the static buffer
+                static_tensor.copy_(runtime_tensor)
+
+                # replace the tensor in list_args with the static buffer
+                list_args[index] = static_tensor
+            return self.split_gm(*list_args)
+
+        return copy_and_call
 
 
 @dataclasses.dataclass
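The copy_and_call wrapper above works because a tensor slice is a view, not a copy. A short standalone sketch (plain PyTorch, not vLLM code) of that property: writing through input_buffers[i][:runtime_shape] fills exactly the memory region a cudagraph captured on that slice would read from.

import torch

# A slice shares storage with its parent tensor.
buffer = torch.zeros(100)
view = buffer[:2]
assert view.data_ptr() == buffer.data_ptr()  # same starting address

view.copy_(torch.tensor([3.0, 1.0]))  # writes into `buffer` itself
assert buffer[0] == 3.0 and buffer[1] == 1.0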
vllm/compilation/config.py (6 changes: 6 additions & 0 deletions)

@@ -32,6 +32,11 @@ class CompilationConfig(BaseModel):
     It means the first several runs will be treated as warmup runs.
     Only after that, the execution will be recorded, and the recorded
     cudagraph will be used for subsequent runs.
+    - cudagraph_copy_inputs: whether to copy input tensors for
+        cudagraph. If the caller can guarantee that the same input buffers
+        are always used, it can set this to False. Otherwise, it should
+        set this to True, and the compiler will copy the input to an
+        internally managed buffer. Default is False.
     - Inductor compilation:
         - use_inductor: whether to use inductor compilation.
             - False: inductor compilation is not used. graph runs in eager.
@@ -78,6 +83,7 @@
     non_cudagraph_ops: List[str] = Field(default_factory=list)
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[List[int]] = None
+    cudagraph_copy_inputs: bool = False
 
     dump_graph_stages: List[str] = Field(default_factory=list)
     dump_graph_dir: Path = Field(default=Path("."))
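For reference, a minimal sketch of enabling the new flag the way the updated test does: write a compilation config JSON and point VLLM_TORCH_COMPILE_CONFIG at it. The field names come from this PR; the temp-file path and surrounding code are illustrative.

import json
import os
import tempfile

# Field names match this PR; "silly.attention" is the test op from above.
config = {
    "use_cudagraph": True,
    "non_cudagraph_ops": ["silly.attention"],
    "cudagraph_copy_inputs": True,
}
path = os.path.join(tempfile.gettempdir(), "piecewise_compilation_config.json")
with open(path, "w") as f:
    json.dump(config, f)

# vLLM's compilation machinery reads this env var, as in the test above.
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = path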