@@ -389,6 +389,8 @@ class VllmBackend:
389
389
returned_callable : Callable
390
390
# Inductor passes to run on the graph pre-defunctionalization
391
391
post_grad_passes : Sequence [Callable ]
392
+ sym_tensor_indices : List [int ]
393
+ input_buffers : List [torch .Tensor ]
392
394
393
395
def __init__ (self , post_grad_passes : Sequence [Callable ] = ()):
394
396
global global_graph_pool
@@ -401,6 +403,9 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()):
401
403
self .graph_pool = global_graph_pool
402
404
self .post_grad_passes = post_grad_passes
403
405
406
+ self .sym_tensor_indices = []
407
+ self .input_buffers = []
408
+
404
409
# `torch.compile` is JIT compiled, so we don't need to
405
410
# do anything here
406
411
@@ -461,7 +466,46 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
461
466
462
467
self ._called = True
463
468
464
- return self .split_gm
469
+ if not self .compilation_configs .use_cudagraph or \
470
+ not self .compilation_configs .cudagraph_copy_inputs :
471
+ return self .split_gm
472
+
473
+ # if we need to copy input buffers for cudagraph
474
+ from torch ._guards import detect_fake_mode
475
+ fake_mode = detect_fake_mode ()
476
+ fake_args = [
477
+ fake_mode .from_tensor (t ) if isinstance (t , torch .Tensor ) else t
478
+ for t in example_inputs
479
+ ]
480
+
481
+ # index of tensors that have symbolic shapes (batch size)
482
+ self .sym_tensor_indices = [
483
+ i for i , x in enumerate (fake_args )
484
+ if isinstance (x , torch ._subclasses .fake_tensor .FakeTensor )
485
+ ]
486
+
487
+ # compiler managed cudagraph input buffers
488
+ # we assume the first run with symbolic shapes
489
+ # has the maximum size among all the tensors
490
+ self .input_buffers = [
491
+ example_inputs [x ].clone () for x in self .sym_tensor_indices
492
+ ]
493
+
494
+ def copy_and_call (* args ):
495
+ list_args = list (args )
496
+ for i , index in enumerate (self .sym_tensor_indices ):
497
+ runtime_tensor = list_args [index ]
498
+ runtime_shape = runtime_tensor .shape [0 ]
499
+ static_tensor = self .input_buffers [i ][:runtime_shape ]
500
+
501
+ # copy the tensor to the static buffer
502
+ static_tensor .copy_ (runtime_tensor )
503
+
504
+ # replace the tensor in the list_args to the static buffer
505
+ list_args [index ] = static_tensor
506
+ return self .split_gm (* list_args )
507
+
508
+ return copy_and_call
465
509
466
510
467
511
@dataclasses .dataclass
0 commit comments