diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 9d5c68274374e..8fa10e5bd1b37 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -1,7 +1,9 @@ from copy import deepcopy -from typing import Callable +from typing import Callable, Union -import torch +from torch import fx + +from vllm.compilation.inductor_pass import InductorPass class TestBackend: @@ -11,19 +13,21 @@ class TestBackend: It also saves the graph before and after the custom passes for inspection. """ - def __init__(self, *args: Callable[[torch.fx.Graph], None]): - self.custom_passes = args + def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], + None]]): + self.custom_passes = list(passes) from torch._inductor import config self.current_config = config.shallow_copy_dict() + self.current_config['force_disable_caches'] = True self.current_config['post_grad_custom_post_pass'] = self.post_pass - def __call__(self, graph: torch.fx.GraphModule, example_inputs): + def __call__(self, graph: fx.GraphModule, example_inputs): from torch._inductor.compile_fx import compile_fx return compile_fx(graph, example_inputs, config_patches=self.current_config) - def post_pass(self, graph: torch.fx.Graph): + def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) for pass_ in self.custom_passes: pass_(graph) diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py new file mode 100644 index 0000000000000..5036189077be2 --- /dev/null +++ b/tests/compile/test_functionalization.py @@ -0,0 +1,95 @@ +import pytest +import torch + +import vllm.envs as envs +from vllm import LLM, SamplingParams +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fusion import (FusionPass, find_auto_fn, + find_auto_fn_maybe) +from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.compilation.vllm_inductor_pass import is_func +from vllm.config import CompilationConfig + +from .backend import TestBackend + +OPS_IN_MODEL = [ + torch.ops._C.rotary_embedding.default, + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.silu_and_mul.default, +] + +RMS_OP = torch.ops._C.rms_norm.default + +RMS_QUANT_OPS = { + "static_fp8": [ + torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + ], +} + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +@pytest.mark.parametrize("model", + ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"]) +@pytest.mark.parametrize("do_fusion", [True, False]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", + reason="Only test on CUDA") +def test_fix_functionalization(model: str, do_fusion: bool): + torch.set_default_device("cuda") + + config = CompilationConfig.PassConfig(enable_fusion=do_fusion, + enable_reshape=True) + reshape_pass = RedundantReshapesPass(config) + fusion_pass = FusionPass.instance(config) + + passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass] + func_pass = FixFunctionalizationPass(config) + backend_func = TestBackend(*passes, func_pass) + backend_no_func = TestBackend(*passes) + + # instantiate a full engine and manually compile the model 2x + # (with and without FixFunctionalizationPass) + llm = LLM(model=model, enforce_eager=True) + model_runner = llm.llm_engine.model_executor.driver_worker.model_runner + orig_model = model_runner.model + # TODO mark inputs dynamic? (currently torch.compile is triggered 4x) + # Can only do that by using the decorator but then we'd have to instantiate + # 2 LLM instances. + + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + model_runner.model = torch.compile(orig_model, + fullgraph=True, + backend=backend_func) + gen_func = llm.generate(prompts, sampling_params) + + model_runner.model = torch.compile(orig_model, + fullgraph=True, + backend=backend_no_func) + gen_no_func = llm.generate(prompts, sampling_params) + + for output_func, output_no_func in zip(gen_func, gen_no_func): + assert output_func.outputs[0].text == output_no_func.outputs[0].text + + # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, + # and replaced by fused quantized ops in RMS_QUANT_OPS. + ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"] + if do_fusion else [RMS_OP]) + + for op in ops: + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, + op) is None # noqa: E501 + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in ops: + if is_func(node, op): + found[op] = True + assert all(found[op] for op in ops) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4db79b070fd8d..f92ec8d0de5f1 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -38,12 +38,6 @@ def forward(self, x): return y3 -# Init does pattern registration, which can only happen once -config = CompilationConfig(enable_fusion=True) -reshape_pass = RedundantReshapesPass(config) -fusion_pass = FusionPass.instance(config) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @@ -58,6 +52,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): pytest.skip("Only test eps=1e-5 for now") # Reshape pass is needed for the fusion pass to work + config = CompilationConfig.PassConfig(enable_fusion=True, + enable_reshape=True) + reshape_pass = RedundantReshapesPass(config) + fusion_pass = FusionPass.instance(config) + backend = TestBackend(reshape_pass, fusion_pass) model = TestModel(hidden_size, eps) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py new file mode 100644 index 0000000000000..03e7535093c5d --- /dev/null +++ b/tests/compile/test_pass_manager.py @@ -0,0 +1,35 @@ +import pickle + +import pytest +import torch +from torch._inductor.codecache import BypassFxGraphCache + +from vllm.compilation.config import CompilationConfig +from vllm.compilation.inductor_pass import (CallableInductorPass, + as_inductor_pass) +from vllm.compilation.pass_manager import PostGradPassManager + + +def simple_callable(graph: torch.fx.Graph): + pass + + +@as_inductor_pass(files=(__file__, )) +def callable_decorated(graph: torch.fx.Graph): + pass + + +@pytest.mark.parametrize( + "works, callable", + [(False, simple_callable), (True, callable_decorated), + (True, CallableInductorPass(simple_callable, "simple_callable"))]) +def test_pass_manager(works: bool, callable): + config = CompilationConfig().pass_config + pass_manager = PostGradPassManager([callable]) + pass_manager.configure(config) # Adds default passes + + if works: + pickle.dumps(pass_manager) + else: + with pytest.raises(BypassFxGraphCache): + pickle.dumps(pass_manager) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 416cffd326489..464bc2af8fd6d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,6 +1,5 @@ import copy import dataclasses -import operator from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -11,205 +10,15 @@ import vllm.envs as envs from vllm.config import CompilationConfig from vllm.logger import init_logger -from vllm.utils import combine_fx_passes, weak_ref_tensors +from vllm.utils import weak_ref_tensors from .counter import compilation_counter -from .fusion import FusionPass -from .reshapes import RedundantReshapesPass +from .inductor_pass import InductorPass +from .pass_manager import PostGradPassManager logger = init_logger(__name__) -def fix_functionalization(graph: fx.Graph): - """ - Rewrite the graph module to replace the pattern involving - torch._higher_order_ops.auto_functionalize.auto_functionalized - with a direct call to the inplace custom op. - - # TODO: check if PyTorch nightly has fixed this issue - """ - - # debug code, if we want to see the graph before the transformation - # with open("before.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) - - nodes_to_remove = [] - - for node in graph.nodes: - # Identify the auto_functionalized node - if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa - if node.args[0] == torch.ops._C.rotary_embedding.default: - # manual replace for rotary_embedding - - # Now, collect the arguments - kwargs = node.kwargs - - query = kwargs['query'] - mm_node = query.args[0].args[0] - - # Create a new call to torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function(torch.ops._C.rotary_embedding.default, - kwargs=kwargs) - - # Remove the auto_functionalized node - # Since the node may have outputs, we need to handle its users - # Replace uses of the outputs (getitem nodes) with mm_node - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - for getitem_user in list(user.users): - if (getitem_user.op == 'call_function' - and getitem_user.target - == torch.ops.aten.slice_scatter.default): - # Replace the uses of slice_scatter node - # with mm_node - getitem_user.replace_all_uses_with(mm_node) - nodes_to_remove.append(getitem_user) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: - # manual replace for fused_add_rms_norm - # this is the most effective optimization for llama - # failing to do this will result in many unnecessary copies - - kwargs = node.kwargs - - input = kwargs['input'] - residual = kwargs['residual'] - - # Create a new call to torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - if user.args[1] == 1: - replace_node = input - elif user.args[1] == 2: - replace_node = residual - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - elif (node.args[0] == - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default): - # manual replace for fused_add_rms_norm_static_fp8_quant - # this is the most effective optimization for llama - # failing to do this will result in many unnecessary copies - - kwargs = node.kwargs - - result = kwargs['result'] - residual = kwargs['residual'] - - # Create a new call to - # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.fused_add_rms_norm_static_fp8_quant. - default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - if user.args[1] == 1: - replace_node = result - elif user.args[1] == 2: - replace_node = residual - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.rms_norm.default: - # manual replace for rms_norm - - kwargs = node.kwargs - - replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rms_norm.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function(torch.ops._C.rms_norm.default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[ - 0] == torch.ops._C.rms_norm_static_fp8_quant.default: # noqa - # manual replace for rms_norm_static_fp8_quant - - kwargs = node.kwargs - - replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.rms_norm_static_fp8_quant.default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.silu_and_mul.default: - # manual replace for silu_and_mul - - kwargs = node.kwargs - - input = kwargs['input'] - out = kwargs['out'] - - # Create a new call to torch.ops._C.silu_and_mul.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.silu_and_mul.default, - args=(out, input), - ) - replace_node = out - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - # Remove the nodes all at once - for node in nodes_to_remove: - graph.erase_node(node) - - # debug code, if we want to see the graph after the transformation - # with open("after.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) - - def wrap_inductor(graph, example_inputs, additional_inductor_config, @@ -368,12 +177,8 @@ class VllmBackend: The major work of this backend is to split the graph into piecewise graphs, and pass them to the piecewise backend. - This backend also handles custom passes and adds them to Inductor config. - The order of the post-grad post-passes is: - 1. post_grad_passes (constructor parameter) - 2. config["post_grad_custom_post_pass"] - 3. fix_functionalization - This way, all passes operate on a functionalized graph. + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. """ compilation_configs: CompilationConfig @@ -402,7 +207,9 @@ def __init__( # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = [] + + # Passes to run on the graph post-grad. + self.post_grad_pass_manager = PostGradPassManager() self.sym_tensor_indices = [] self.input_buffers = [] @@ -412,24 +219,19 @@ def __init__( # `torch.compile` is JIT compiled, so we don't need to # do anything here - def add_passes_to_config(self): + def configure_post_pass(self): config = self.compilation_configs - passes = list(self.post_grad_passes) - - passes = passes + [RedundantReshapesPass(config)] - - if config.enable_fusion: - passes = passes + [FusionPass.instance(config)] + self.post_grad_pass_manager.configure(config.pass_config) + # Post-grad custom passes are run using the post_grad_custom_post_pass + # hook. If a pass for that hook exists, add it to the pass manager. inductor_config = config.inductor_compile_config - if "post_grad_custom_post_pass" in inductor_config: - passes = passes + [inductor_config["post_grad_custom_post_pass"]] - - # add the fix_functionalization pass last, so that all other - # passes operate on a functionalized graph - passes = passes + [fix_functionalization] - combined_pass = combine_fx_passes(passes) - inductor_config["post_grad_custom_post_pass"] = combined_pass + PASS_KEY = "post_grad_custom_post_pass" + if PASS_KEY in inductor_config: + # Config should automatically wrap all inductor passes + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + inductor_config[PASS_KEY] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: @@ -444,7 +246,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # we get the sizes to capture for cudagraph # from compilation context self.compilation_configs.init_during_runtime() - self.add_passes_to_config() + self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.splitting_ops) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py new file mode 100644 index 0000000000000..3584cc3608caf --- /dev/null +++ b/vllm/compilation/fix_functionalization.py @@ -0,0 +1,177 @@ +import operator +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass, is_func + +logger = init_logger(__name__) + + +class FixFunctionalizationPass(VllmInductorPass): + """ + This pass defunctionalizes certain nodes to avoid redundant tensor copies. + After this pass, DCE (dead-code elimination) should never be run, + as de-functionalized nodes may appear as dead code. + + To add new nodes to defunctionalize, add to the if-elif chain in __call__. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_fix_functionalization") + + self.nodes_to_remove: List[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # rotary_embedding is a special case: the two mutating inputs + # are query and key, which are slices of mm_node. + # While functionalized, results at[1] and at[2] are scattered + # back into mm_node. After de-functionalization, we can just + # use mm_node directly. + for idx, user in self.getitem_users(node).items(): + for user_of_getitem in user.users: + if is_func(user_of_getitem, + torch.ops.aten.slice_scatter.default): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + # These 2 replacements avoid the most copies for LLaMa. + elif at_target == torch.ops._C.fused_add_rms_norm.default: + mutated_args = {1: 'input', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + + elif at_target in [ + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default + ]: + mutated_args = {1: 'result'} + self.defunctionalize(graph, node, mutated_args) + + elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: 'out'} + # Because we have an 'out', need to specify args directly + self.defunctionalize(graph, + node, + mutated_args, + args=('out', 'input')) + else: + continue # skip the count + + count += 1 + + self.dump_graph(graph, "before_fix_functionalization_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug("De-functionalized %s nodes, removed %s nodes", count, + count_removed) + self.dump_graph(graph, "after_fix_functionalization") + self.end_and_log() + + def _remove(self, node_or_nodes: Union[torch.fx.Node, + Iterable[torch.fx.Node]]): + """ + Stage a node (or nodes) for removal at the end of the pass. + """ + if isinstance(node_or_nodes, torch.fx.Node): + self.nodes_to_remove.append(node_or_nodes) + else: + self.nodes_to_remove.extend(node_or_nodes) + + def defunctionalize(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: Dict[int, Union[torch.fx.Node, str]], + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + De-functionalize a node by replacing it with a call to the original. + It also replaces the getitem users with the mutated arguments. + See replace_users_with_mutated_args and insert_defunctionalized. + """ + self.replace_users_with_mutated_args(node, mutated_args) + self.insert_defunctionalized(graph, node, args=args) + self._remove(node) + + def replace_users_with_mutated_args(self, node: torch.fx.Node, + mutated_args: Dict[int, + Union[torch.fx.Node, + str]]): + """ + Replace all getitem users of the auto-functionalized node with the + mutated arguments. + :param node: The auto-functionalized node + :param mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. + """ + for idx, user in self.getitem_users(node).items(): + arg = mutated_args[idx] + arg = node.kwargs[arg] if isinstance(arg, str) else arg + user.replace_all_uses_with(arg) + self._remove(user) + + def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]: + """ + Returns the operator.getitem users of the auto-functionalized node, + indexed by the index they are getting. + """ + users = {} + for user in node.users: + if is_func(user, operator.getitem): + idx = user.args[1] + users[idx] = user + return users + + def insert_defunctionalized(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + Insert a new defunctionalized node into the graph before node. + If one of the kwargs is 'out', provide args directly, + as node.kwargs cannot be used. + See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 + + :param graph: Graph to insert the defunctionalized node into + :param node: The auto-functionalized node to defunctionalize + :param args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. + """ # noqa: E501 + assert is_func(node, auto_functionalized), \ + f"node must be auto-functionalized, is {node} instead" + + # Create a new call to the original function + with graph.inserting_before(node): + function = node.args[0] + if args is None: + graph.call_function(function, kwargs=node.kwargs) + else: + # Args passed as strings refer to items in node.kwargs + args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg + for arg in args) + graph.call_function(function, args=args) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e6a3afef85e1b..5efa410fab6a0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,10 +6,11 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.inductor_pass import InductorPass from vllm.config import CompilationConfig from vllm.logger import init_logger +from .vllm_inductor_pass import VllmInductorPass, is_func + logger = init_logger(__name__) @@ -90,8 +91,6 @@ def empty_fp32(*args, **kwargs): # Utilities for post-processing multi-output matches -def is_func(node: torch.fx.Node, target) -> bool: - return node.op == "call_function" and node.target == target # Returns the first auto_functionalized node with the given op (if it exists) @@ -127,7 +126,7 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: return ret -class FusionPass(InductorPass): +class FusionPass(VllmInductorPass): """ This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them. @@ -142,7 +141,7 @@ class FusionPass(InductorPass): _instance: 'Optional[FusionPass]' = None @classmethod - def instance(cls, config: CompilationConfig): + def instance(cls, config: CompilationConfig.PassConfig): """ Get the singleton instance of the FusionPass. If the instance exists, the config is updated but @@ -154,7 +153,7 @@ def instance(cls, config: CompilationConfig): cls._instance.config = config return cls._instance - def __init__(self, config: CompilationConfig): + def __init__(self, config: CompilationConfig.PassConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" super().__init__(config) @@ -278,6 +277,7 @@ def process_matches(self, graph: torch.fx.Graph): for node in match.nodes) def __call__(self, graph: torch.fx.Graph): + self.begin() self.dump_graph(graph, "before_fusion") count = self.patterns.apply(graph) @@ -289,3 +289,4 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Post-processed %s matches", len(self.matches)) self.dump_graph(graph, "after_fusion") self.matches.clear() + self.end_and_log() diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 8082a08b40019..f6846c08ac841 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,38 +1,84 @@ +import hashlib +import inspect +import types from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union import torch - -from vllm.config import CompilationConfig -# yapf: disable -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank -from vllm.distributed import ( - get_tensor_model_parallel_world_size as get_tp_world_size) -from vllm.distributed import model_parallel_is_initialized as p_is_init -# yapf: enable -from vllm.logger import init_logger - -logger = init_logger(__name__) +from torch import fx class InductorPass(ABC): + """ + General custom inductor pass interface. + TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass + """ @abstractmethod def __call__(self, graph: torch.fx.Graph): + """ + Execute the pass on the given graph. + """ raise NotImplementedError - def __init__(self, config: CompilationConfig): - self.config = config - - def dump_graph(self, graph: torch.fx.Graph, stage: str): - if stage in self.config.dump_graph_stages: - # Make sure filename includes rank in the distributed setting - parallel = p_is_init() and get_tp_world_size() > 1 - rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" - - logger.info("Printing graph to %s", filepath) - with open(filepath, "w") as f: - src = graph.python_code(root_module="self", verbose=True).src - # Add imports so it's not full of errors - print("import torch; from torch import device", file=f) - print(src, file=f) + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.digest() + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. + """ + + def __init__(self, + callable: Callable[[fx.Graph], None], + uuid: Optional[Any] = None): + self.callable = callable + if uuid is None: + uuid = InductorPass.hash_source(callable) + self._uuid = uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid + + def __getstate__(self): + """ + Pickling occurs in the Inductor code cache if a pass is not given to + the pass manager but is instead directly added to config as a pass. + See PostGradPassManager for more. + + TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + """ + return self._uuid + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CallableInductorPass") diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py new file mode 100644 index 0000000000000..fb522ae053e97 --- /dev/null +++ b/vllm/compilation/pass_manager.py @@ -0,0 +1,77 @@ +from typing import List + +from torch import fx as fx + +from vllm.config import CompilationConfig +from vllm.logger import init_logger + +from .fix_functionalization import FixFunctionalizationPass +from .fusion import FusionPass +from .inductor_pass import InductorPass +from .reshapes import RedundantReshapesPass + +logger = init_logger(__name__) + + +class PostGradPassManager: + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It also supports pickling, which is used by the Inductor code cache. + TODO(torch==2.6), use CustomGraphPass + (torch._inductor.custom_graph_pass.CustomGraphPass) + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (RedundantReshapesPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: List[InductorPass] = [] + + def __call__(self, graph: fx.Graph): + for pass_ in self.passes: + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure(self, pass_config: CompilationConfig.PassConfig): + self.pass_config = pass_config + if pass_config.enable_reshape: + self.passes += [RedundantReshapesPass(pass_config)] + + if pass_config.enable_fusion: + self.passes += [FusionPass.instance(pass_config)] + + self.fix_functionalization = FixFunctionalizationPass(pass_config) + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def __getstate__(self): + """ + Custom pickling for the pass manager, as some passes cannot be pickled. + Pickling occurs because the pass manager is set as the value of + `config["post_grad_custom_post_pass"]` in the Inductor config. + The config is pickled to act as a key in the Inductor code cache. + Any other passes in the config are pickled as well. + + TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + """ + state = {"pass_config": self.pass_config.uuid(), "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return state + + def __setstate__(self, state): + """ + Do not allow unpickling of the pass manager. + If this is needed in the future, it should properly pickle the passes. + """ + raise ValueError("Cannot unpickle PostGradPassManager") diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py index 36597e119d2e1..63a369fe8d966 100644 --- a/vllm/compilation/reshapes.py +++ b/vllm/compilation/reshapes.py @@ -3,14 +3,14 @@ import torch.fx from torch import SymInt -from vllm.compilation.fusion import is_func -from vllm.compilation.inductor_pass import InductorPass from vllm.logger import init_logger +from .vllm_inductor_pass import VllmInductorPass, is_func + logger = init_logger(__name__) -class RedundantReshapesPass(InductorPass): +class RedundantReshapesPass(VllmInductorPass): """ This is an inductor pass that removes redundant reshape operations. It is required for RMSNorm-quant fusion to work properly. @@ -31,6 +31,7 @@ class RedundantReshapesPass(InductorPass): """ def __call__(self, graph: torch.fx.Graph): + self.begin() self.dump_graph(graph, "before_reshapes") count = 0 # Remove no-op reshapes/views: @@ -56,6 +57,7 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Removed %s no-op reshapes", count) self.dump_graph(graph, "after_reshapes") + self.end_and_log() def dims_equivalent(self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt]) -> bool: diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py new file mode 100644 index 0000000000000..dbf6b8f7789e1 --- /dev/null +++ b/vllm/compilation/vllm_inductor_pass.py @@ -0,0 +1,53 @@ +import time + +import torch + +from vllm.config import CompilationConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +from .inductor_pass import InductorPass + +logger = init_logger(__name__) + + +def is_func(node: torch.fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +class VllmInductorPass(InductorPass): + """ + An inductor pass with access to vLLM PassConfig. + It provides timing, logging, and dumping utilities. + """ + + def __init__(self, config: CompilationConfig.PassConfig): + self.config = config + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + if stage in self.config.dump_graph_stages: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("%s printing graph to %s", self.pass_name, filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) diff --git a/vllm/config.py b/vllm/config.py index 0ed92f370cf50..b2785e1ce2d5f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,5 +1,6 @@ import copy import enum +import hashlib import json import warnings from dataclasses import dataclass, field, replace @@ -13,6 +14,7 @@ from transformers import PretrainedConfig import vllm.envs as envs +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) @@ -2120,12 +2122,7 @@ class CompilationConfig(BaseModel): name because the config uses json format. If we pass the config from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. + - custom inductor passes: see PassConfig for more details Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -2157,9 +2154,43 @@ class CompilationConfig(BaseModel): cudagraph_capture_sizes: Optional[List[int]] = None cudagraph_copy_inputs: bool = False - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True + class PassConfig(BaseModel): + """ + Configuration for custom Inductor passes. + This is separate from general CompilationConfig so that inductor passes + don't all have access to full configuration - that would create a cycle + as the PassManager is set as a property of config. + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graphs. Default is . + - enable_fusion: whether to enable the custom fusion pass. + - enable_reshape: whether to enable the custom reshape elimination pass. + TODO better pass enabling system. + """ + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + enable_reshape: bool = True + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + dict_ = self.model_dump( + include={"enable_fusion", "enable_reshape"}) + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).digest() + + def model_post_init(self, __context: Any) -> None: + if not self.enable_reshape and self.enable_fusion: + print_warning_once( + "Fusion enabled but reshape elimination disabled." + "RMSNorm + quant (fp8) fusion might not work") + + pass_config: PassConfig = Field(default_factory=PassConfig) # not configurable, computed after init compile_sizes: List[int] = PrivateAttr @@ -2185,8 +2216,9 @@ def model_post_init(self, __context: Any) -> None: for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) continue # resolve function from qualified name @@ -2194,7 +2226,8 @@ def model_post_init(self, __context: Any) -> None: module = ".".join(names[:-1]) func_name = names[-1] func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) self.enabled_custom_ops = Counter() self.disabled_custom_ops = Counter() @@ -2344,7 +2377,8 @@ def __post_init__(self): self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True - self.compilation_config.enable_fusion = False + self.compilation_config.pass_config.enable_fusion = False + self.compilation_config.pass_config.enable_reshape = False current_platform.check_and_update_config(self) diff --git a/vllm/utils.py b/vllm/utils.py index 2bbdc8d1ebde8..cb2ad43a2ae8d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1501,15 +1501,6 @@ def __len__(self): return len(self._factory) -def combine_fx_passes(passes: List[Callable]) -> Callable: - - def combined_fx(graph) -> None: - for fx in passes: - fx(graph) - - return combined_fx - - def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: """ Create a weak reference to a tensor. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1f9b544637bf7..5f66293cbe8e4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -548,7 +548,7 @@ def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. Please add " - "-O 3 to use CUDA graphs.", CompilationLevel.PIECEWISE) + "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) return start_time = time.perf_counter()