Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[TRANSFORMATIONS][GPU] Adjust SDPA Fusion pass #29575

Open
wants to merge 26 commits into
base: master
Choose a base branch
from

Conversation

merezman
Copy link

@merezman merezman commented Mar 19, 2025

Details:

  • Adjusted SDPA fusion pass to be able to handle patterns observed in many customer models, that requires to include reshapes to the pattern that changes 4D tensors to 3D or 2D. It handles scales that can be attached in different places, and inner reshapes ie. reshape of softmax output.

@github-actions github-actions bot added category: GPU OpenVINO GPU plugin category: transformations OpenVINO Runtime library - Transformations labels Mar 19, 2025
@sys-openvino-ci sys-openvino-ci added the ExternalIntelPR External contributor from Intel label Mar 19, 2025
Comment on lines 29 to 32
auto q_base = makePattern(ov::Rank(4));
auto q_shape = ov::pass::pattern::any_input();
auto q_reshaped = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({q_base, q_shape});
auto q = q_reshaped | q_base;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess pattern::optional can be used here to simplify that part:

auto q_base = makePattern(ov::Rank(4));
auto q = optional<ov::op::v1::Reshape>({q_base, any_input()});

Copy link
Author

@merezman merezman Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used this way for later checks

// make sure that if inputs are reshaped the output is reshaped back
bool inputs_reshaped = pattern_map.count(q_reshaped) > 0 && pattern_map.count(k_reshaped) > 0 && pattern_map.count(v_reshaped) > 0;
bool output_reshaped = pattern_map.count(qkv_reshaped) > 0;
if (inputs_reshaped && !output_reshaped || !inputs_reshaped && output_reshaped)
return false;


// Optional k scale
auto attn_scale = ov::pass::pattern::any_input();
// auto k_opt_scaled = optional<ov::op::v1::Multiply>({k, attn_scale});

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unused code

auto softmax = makePattern<ov::op::v8::Softmax>({optional_add_mask}, {{"axis", "-1"}});
auto qkv = makePattern<ov::op::v0::MatMul>({softmax, v}, {{"transpose_a", false}, {"transpose_b", false}});
// Optional reshape befor adding mask
auto qk_opt_scaled_pre_mask_shape = ov::pass::pattern::any_input();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have using ov::pass::pattern in the beginning of the function, so you can omit explicit namespace specification

auto qk_opt_scaled_pre_mask_reshaped = ov::pass::pattern::wrap_type<ov::op::v1::Reshape>({qk_opt_scaled, qk_opt_scaled_pre_mask_shape});
auto qk_opt_scaled_pre_mask_opt_reshaped = qk_opt_scaled_pre_mask_reshaped | qk_opt_scaled;
// Optional mask add
auto qk_opt_scaled_mask_added = makePattern<ov::op::v1::Add>({qk_opt_scaled_pre_mask_opt_reshaped, mask});

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think makePattern shall be replaced here with wrap_type

Comment on lines 51 to 52
auto k_opt_transposed_scaled = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({k_opt_transposed, attn_scale});
auto k_opt_transposed_opt_scaled = k_opt_transposed_scaled | k_opt_transposed;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think optional can be used here too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is also used for checks

Comment on lines 190 to 198
ov::Output<ov::Node> qk_out;
if (pattern_map.count(qk_opt_scaled_pre_mask_reshaped) > 0)
qk_out = pattern_map.at(qk_opt_scaled_pre_mask_reshaped);
else if (pattern_map.count(qk_unsqueeze) > 0)
qk_out = pattern_map.at(qk_unsqueeze);
else if (pattern_map.count(qk) > 0)
qk_out = pattern_map.at(qk);
else
return false;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably can be simplified by taking first input of qk_opt_scaled_mask_added node.

@merezman merezman requested a review from sshlyapn March 21, 2025 10:51
@merezman merezman marked this pull request as ready for review March 24, 2025 12:06
@merezman merezman requested review from a team as code owners March 24, 2025 12:06
@merezman merezman requested review from itikhono and removed request for a team March 24, 2025 12:06
Copy link
Contributor

@sshlyapn sshlyapn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general looks good to me

auto v = makePattern(ov::Rank(4));
auto q_base = makePattern(ov::Rank(4));
auto q_shape = any_input();
auto q_reshaped = wrap_type<ov::op::v1::Reshape>({q_base, q_shape});
Copy link
Contributor

@itikhono itikhono Mar 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: can we use
pattern::optional for v1::Reshape?

class OPENVINO_API Optional : public Pattern {

it should work as pattern::Or inside.

the same for K, V
and probably for QK

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used this way to be able to use those checks

// make sure that if inputs are reshaped the output is reshaped back
bool inputs_reshaped = pattern_map.count(q_reshaped) > 0 && pattern_map.count(k_reshaped) > 0 && pattern_map.count(v_reshaped) > 0;
bool output_reshaped = pattern_map.count(qkv_reshaped) > 0;
if (inputs_reshaped && !output_reshaped || !inputs_reshaped && output_reshaped)
return false;

Copy link
Contributor

@itikhono itikhono Mar 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it should work the same way with auto q_reshaped = optional<Reshape>(in1, in2)> etc could you double check?
As I can see, we have a test for this

EXPECT_EQ(pattern_val_mp.count(pattern_convert), 0);

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, seems that it works

@itikhono itikhono requested review from CuriousPanCake and removed request for vladimir-paramuzov March 27, 2025 13:21
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
category: GPU OpenVINO GPU plugin category: transformations OpenVINO Runtime library - Transformations ExternalIntelPR External contributor from Intel
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants