|
| 1 | +import operator |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.fx as fx |
| 5 | + |
| 6 | + |
| 7 | +def fix_functionalization(graph: fx.Graph): |
| 8 | + """ |
| 9 | + Rewrite the graph module to replace the pattern involving |
| 10 | + torch._higher_order_ops.auto_functionalize.auto_functionalized |
| 11 | + with a direct call to the inplace custom op. |
| 12 | +
|
| 13 | + # TODO: check if PyTorch nightly has fixed this issue |
| 14 | + """ |
| 15 | + |
| 16 | + # debug code, if we want to see the graph before the transformation |
| 17 | + # with open("before.py", "w") as f: |
| 18 | + # print(graph.python_code(root_module="self", verbose=True).src, file=f) |
| 19 | + |
| 20 | + nodes_to_remove = [] |
| 21 | + |
| 22 | + for node in graph.nodes: |
| 23 | + # Identify the auto_functionalized node |
| 24 | + if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa |
| 25 | + if node.args[0] == torch.ops._C.rotary_embedding.default: |
| 26 | + # manual replace for rotary_embedding |
| 27 | + |
| 28 | + # Now, collect the arguments |
| 29 | + kwargs = node.kwargs |
| 30 | + |
| 31 | + query = kwargs['query'] |
| 32 | + mm_node = query.args[0].args[0] |
| 33 | + |
| 34 | + # Create a new call to torch.ops._C.rotary_embedding.default |
| 35 | + with graph.inserting_before(node): |
| 36 | + # just insert the call to the custom op |
| 37 | + # NOTE: don't run dead code elimination, |
| 38 | + # otherwise this op will be removed |
| 39 | + graph.call_function(torch.ops._C.rotary_embedding.default, |
| 40 | + kwargs=kwargs) |
| 41 | + |
| 42 | + # Remove the auto_functionalized node |
| 43 | + # Since the node may have outputs, we need to handle its users |
| 44 | + # Replace uses of the outputs (getitem nodes) with mm_node |
| 45 | + for user in list(node.users): |
| 46 | + if user.op == 'call_function' and user.target == operator.getitem: # noqa |
| 47 | + # Remove the getitem node |
| 48 | + for getitem_user in list(user.users): |
| 49 | + if (getitem_user.op == 'call_function' |
| 50 | + and getitem_user.target |
| 51 | + == torch.ops.aten.slice_scatter.default): |
| 52 | + # Replace the uses of slice_scatter node |
| 53 | + # with mm_node |
| 54 | + getitem_user.replace_all_uses_with(mm_node) |
| 55 | + nodes_to_remove.append(getitem_user) |
| 56 | + nodes_to_remove.append(user) |
| 57 | + nodes_to_remove.append(node) |
| 58 | + |
| 59 | + elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: |
| 60 | + # manual replace for fused_add_rms_norm |
| 61 | + # this is the most effective optimization for llama |
| 62 | + # failing to do this will result in many unnecessary copies |
| 63 | + |
| 64 | + kwargs = node.kwargs |
| 65 | + |
| 66 | + input = kwargs['input'] |
| 67 | + residual = kwargs['residual'] |
| 68 | + |
| 69 | + # Create a new call to torch.ops._C.rotary_embedding.default |
| 70 | + with graph.inserting_before(node): |
| 71 | + # just insert the call to the custom op |
| 72 | + # NOTE: don't run dead code elimination, |
| 73 | + # otherwise this op will be removed |
| 74 | + graph.call_function( |
| 75 | + torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) |
| 76 | + |
| 77 | + for user in list(node.users): |
| 78 | + if user.op == 'call_function' and user.target == operator.getitem: # noqa |
| 79 | + # Remove the getitem node |
| 80 | + if user.args[1] == 1: |
| 81 | + replace_node = input |
| 82 | + elif user.args[1] == 2: |
| 83 | + replace_node = residual |
| 84 | + user.replace_all_uses_with(replace_node) |
| 85 | + nodes_to_remove.append(user) |
| 86 | + nodes_to_remove.append(node) |
| 87 | + |
| 88 | + elif node.args[0] == torch.ops._C.rms_norm.default: |
| 89 | + # manual replace for rms_norm |
| 90 | + |
| 91 | + kwargs = node.kwargs |
| 92 | + |
| 93 | + input = kwargs['input'] |
| 94 | + out = kwargs['out'] |
| 95 | + weight = kwargs['weight'] |
| 96 | + epsilon = kwargs['epsilon'] |
| 97 | + # Create a new call to torch.ops._C.rotary_embedding.default |
| 98 | + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa |
| 99 | + with graph.inserting_before(node): |
| 100 | + # just insert the call to the custom op |
| 101 | + # NOTE: don't run dead code elimination, |
| 102 | + # otherwise this op will be removed |
| 103 | + graph.call_function( |
| 104 | + torch.ops._C.rms_norm.default, |
| 105 | + args=(out, input, weight, epsilon), |
| 106 | + ) |
| 107 | + |
| 108 | + replace_node = out |
| 109 | + |
| 110 | + for user in list(node.users): |
| 111 | + if user.op == 'call_function' and user.target == operator.getitem: # noqa |
| 112 | + user.replace_all_uses_with(replace_node) |
| 113 | + nodes_to_remove.append(user) |
| 114 | + nodes_to_remove.append(node) |
| 115 | + |
| 116 | + elif node.args[0] == torch.ops._C.silu_and_mul.default: |
| 117 | + # manual replace for silu_and_mul |
| 118 | + |
| 119 | + kwargs = node.kwargs |
| 120 | + |
| 121 | + input = kwargs['input'] |
| 122 | + out = kwargs['out'] |
| 123 | + |
| 124 | + # Create a new call to torch.ops._C.rotary_embedding.default |
| 125 | + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa |
| 126 | + with graph.inserting_before(node): |
| 127 | + # just insert the call to the custom op |
| 128 | + # NOTE: don't run dead code elimination, |
| 129 | + # otherwise this op will be removed |
| 130 | + graph.call_function( |
| 131 | + torch.ops._C.silu_and_mul.default, |
| 132 | + args=(out, input), |
| 133 | + ) |
| 134 | + replace_node = out |
| 135 | + |
| 136 | + for user in list(node.users): |
| 137 | + if user.op == 'call_function' and user.target == operator.getitem: # noqa |
| 138 | + user.replace_all_uses_with(replace_node) |
| 139 | + nodes_to_remove.append(user) |
| 140 | + nodes_to_remove.append(node) |
| 141 | + |
| 142 | + # Remove the nodes all at once |
| 143 | + for node in nodes_to_remove: |
| 144 | + graph.erase_node(node) |
| 145 | + |
| 146 | + # debug code, if we want to see the graph after the transformation |
| 147 | + # with open("after.py", "w") as f: |
| 148 | + # print(graph.python_code(root_module="self", verbose=True).src, file=f) |
| 149 | + |
| 150 | + |
| 151 | +def vllm_backend(graph, example_inputs): |
| 152 | + from torch._inductor import config |
| 153 | + current_config = config.shallow_copy_dict() |
| 154 | + from torch._inductor.compile_fx import compile_fx |
| 155 | + current_config['post_grad_custom_post_pass'] = fix_functionalization |
| 156 | + return compile_fx(graph, example_inputs, config_patches=current_config) |
0 commit comments