-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
[TRANSFORMATIONS][GPU] Adjust SDPA Fusion pass #29575
base: master
Are you sure you want to change the base?
Conversation
auto q_base = makePattern(ov::Rank(4)); | ||
auto q_shape = ov::pass::pattern::any_input(); | ||
auto q_reshaped = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({q_base, q_shape}); | ||
auto q = q_reshaped | q_base; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess pattern::optional can be used here to simplify that part:
auto q_base = makePattern(ov::Rank(4));
auto q = optional<ov::op::v1::Reshape>({q_base, any_input()});
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used this way for later checks
openvino/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Lines 146 to 150 in d32a498
// make sure that if inputs are reshaped the output is reshaped back | |
bool inputs_reshaped = pattern_map.count(q_reshaped) > 0 && pattern_map.count(k_reshaped) > 0 && pattern_map.count(v_reshaped) > 0; | |
bool output_reshaped = pattern_map.count(qkv_reshaped) > 0; | |
if (inputs_reshaped && !output_reshaped || !inputs_reshaped && output_reshaped) | |
return false; |
|
||
// Optional k scale | ||
auto attn_scale = ov::pass::pattern::any_input(); | ||
// auto k_opt_scaled = optional<ov::op::v1::Multiply>({k, attn_scale}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove unused code
auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}}); | ||
auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}}); | ||
// Optional reshape befor adding mask | ||
auto qk_opt_scaled_pre_mask_shape = ov::pass::pattern::any_input(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have using ov::pass::pattern
in the beginning of the function, so you can omit explicit namespace specification
auto qk_opt_scaled_pre_mask_reshaped = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({qk_opt_scaled, qk_opt_scaled_pre_mask_shape}); | ||
auto qk_opt_scaled_pre_mask_opt_reshaped = qk_opt_scaled_pre_mask_reshaped | qk_opt_scaled; | ||
// Optional mask add | ||
auto qk_opt_scaled_mask_added = makePattern<ov::op::v1::Add>({qk_opt_scaled_pre_mask_opt_reshaped, mask}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think makePattern
shall be replaced here with wrap_type
auto k_opt_transposed_scaled = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({k_opt_transposed, attn_scale}); | ||
auto k_opt_transposed_opt_scaled = k_opt_transposed_scaled | k_opt_transposed; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think optional
can be used here too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is also used for checks
ov::Output<ov::Node> qk_out; | ||
if (pattern_map.count(qk_opt_scaled_pre_mask_reshaped) > 0) | ||
qk_out = pattern_map.at(qk_opt_scaled_pre_mask_reshaped); | ||
else if (pattern_map.count(qk_unsqueeze) > 0) | ||
qk_out = pattern_map.at(qk_unsqueeze); | ||
else if (pattern_map.count(qk) > 0) | ||
qk_out = pattern_map.at(qk); | ||
else | ||
return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably can be simplified by taking first input of qk_opt_scaled_mask_added
node.
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Show resolved
Hide resolved
src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_ref.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/tests/common_optimizations/sdpa_fusion_test.cpp
Outdated
Show resolved
Hide resolved
src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general looks good to me
auto v = makePattern(ov::Rank(4)); | ||
auto q_base = makePattern(ov::Rank(4)); | ||
auto q_shape = any_input(); | ||
auto q_reshaped = wrap_type<ov::op::v1::Reshape>({q_base, q_shape}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor: can we use
pattern::optional for v1::Reshape?
class OPENVINO_API Optional : public Pattern { |
it should work as pattern::Or inside.
the same for K, V
and probably for QK
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used this way to be able to use those checks
openvino/src/common/transformations/src/transformations/common_optimizations/sdpa_fusion.cpp
Lines 146 to 150 in d32a498
// make sure that if inputs are reshaped the output is reshaped back | |
bool inputs_reshaped = pattern_map.count(q_reshaped) > 0 && pattern_map.count(k_reshaped) > 0 && pattern_map.count(v_reshaped) > 0; | |
bool output_reshaped = pattern_map.count(qkv_reshaped) > 0; | |
if (inputs_reshaped && !output_reshaped || !inputs_reshaped && output_reshaped) | |
return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe it should work the same way with auto q_reshaped = optional<Reshape>(in1, in2)>
etc could you double check?
As I can see, we have a test for this
openvino/src/core/tests/pattern.cpp
Line 579 in 94bc742
EXPECT_EQ(pattern_val_mp.count(pattern_convert), 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, seems that it works
Details: