#include "torch/csrc/jit/ir/subgraph_matcher.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"
#include "torch/csrc/jit/ir/irparser.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
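// The rewriters below replace the fused op with its decomposition:
//   out = softmax((Q @ K^T) * (1 / sqrt(query.size(-1)))) @ V
// optionally adding a bias derived from attn_mask before the softmax; dropout is not applied by
// either replacement.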
void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string sdpa_pattern = R"IR(
    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
      %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal)
      return (%out))IR";

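  // Replacement for calls with no attention mask: scale the Q.K^T scores by 1/sqrt(query.size(-1)),
  // softmax over the last dimension, then multiply by value.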
  std::string unpacked_sdpa_pattern = R"IR(
    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
      %none : NoneType = prim::Constant()
      %1 : int = prim::Constant[value=-1]()
      %2 : int = prim::Constant[value=-2]()
      %3 : int = aten::size(%query, %1)
      %q_size : Long() = prim::NumToTensor(%3)
      %sqrt : Tensor = aten::sqrt(%q_size)
      %scale_factor : Tensor = aten::reciprocal(%sqrt)
      %key_transpose : Tensor = aten::transpose(%key, %2, %1)
      %matmul : Tensor = aten::matmul(%query, %key_transpose)
      %attn_weight : Tensor = aten::mul(%matmul, %scale_factor)
      %softmax : Tensor = aten::softmax(%attn_weight, %1, %none)
      %out : Tensor = aten::matmul(%softmax, %value)
      return (%out))IR";

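  // Replacement for calls with a float/bool attention mask: identical to the pattern above, except
  // a bias produced by trt::attn_bias_from_attn_mask is added to the scaled scores before the
  // softmax.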
  std::string unpacked_sdpa_attn_biased_pattern = R"IR(
    graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal):
      %none : NoneType = prim::Constant()
      %0 : int = prim::Constant[value=1]()
      %1 : int = prim::Constant[value=-1]()
      %2 : int = prim::Constant[value=-2]()
      %3 : int = aten::size(%query, %1)
      %q_size : Long() = prim::NumToTensor(%3)
      %sqrt : Tensor = aten::sqrt(%q_size)
      %scale_factor : Tensor = aten::reciprocal(%sqrt)
      %key_transpose : Tensor = aten::transpose(%key, %2, %1)
      %matmul : Tensor = aten::matmul(%query, %key_transpose)
      %attn_weight : Tensor = aten::mul(%matmul, %scale_factor)
      %attn_bias : Tensor = trt::attn_bias_from_attn_mask(%attn_mask)
      %attn_weight_with_bias : Tensor = aten::add(%attn_weight, %attn_bias, %0)
      %softmax : Tensor = aten::softmax(%attn_weight_with_bias, %1, %none)
      %out : Tensor = aten::matmul(%softmax, %value)
      return (%out))IR";

  // rewrite with None attn_mask
  torch::jit::SubgraphRewriter sdpa_rewriter;
  sdpa_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_pattern);
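  // Filter: only rewrite calls where is_causal is a constant False and attn_mask is a constant None.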
  sdpa_rewriter.runOnGraph(
      graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) {
        auto is_causal_node = match.anchor->inputs().at(5)->node();
        if (is_causal_node->kind() != at::prim::Constant) {
          LOG_WARNING("Could not unpack scaled_dot_product_attention with non constant is_causal: " << *is_causal_node);
          return false;
        }
        if (is_causal_node->i(at::attr::value) == 1) {
          LOG_WARNING("Could not unpack scaled_dot_product_attention with is_causal = True: " << *is_causal_node);
          return false;
        }
        auto attn_mask_node = match.anchor->inputs().at(3)->node();
        if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) {
          return false;
        }
        return true;
      });

  // Rewrite with a float/bool attn_mask. This uses a custom op to implement the divergent behavior
  // between bool and float masks without a conditional.
  torch::jit::SubgraphRewriter sdpa_attn_mask_rewriter;
  sdpa_attn_mask_rewriter.RegisterRewritePattern(sdpa_pattern, unpacked_sdpa_attn_biased_pattern);
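  // Filter: same constant-False is_causal requirement as above; matches not handled by the first
  // rewriter (i.e. with a non-None attn_mask) are rewritten with the biased pattern.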
  sdpa_attn_mask_rewriter.runOnGraph(
      graph, [](const torch::jit::Match& match, const std::unordered_map<std::string, torch::jit::Value*>&) {
        auto is_causal_node = match.anchor->inputs().at(5)->node();
        if (is_causal_node->kind() != at::prim::Constant || is_causal_node->i(at::attr::value) == 1) {
          // messages already written in first pass, do not write again
          return false;
        }
        return true;
      });
  LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph);
}
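
// Rough usage sketch (illustrative only; the IR string is an assumption, not part of this file).
// Given TorchScript IR containing aten::scaled_dot_product_attention, the pass is applied directly
// to the shared graph:
//
//   auto g = std::make_shared<torch::jit::Graph>();
//   torch::jit::parseIR(ir_with_sdpa, g.get());
//   UnpackScaledDotProductAttention(g);
//
// Matched calls are then replaced in place by the matmul/softmax decompositions above.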

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt