Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

【Infer Symbolic Shape No.230】【BUAA】Add stft, changed 3 files #67663

Merged
merged 69 commits into from
Sep 18, 2024

Conversation

Whsjrczr
Copy link
Contributor

@Whsjrczr Whsjrczr commented Aug 22, 2024

PR Category

CINN

PR Types

Others

Description

加入stft,单测位于test\legacy_test\test_stft_op.py
check_dygraph=False,默认的check_pir=False, check_symbol_infer=True


infer_context->AddEqualCstr(window_shape[0], symbol::DimExpr{n_fft});

int seq_length = x_shape[x_rank - 1].Get<std::int64_t>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里加个判断吧,是int类型再get

int seq_length = x_shape[x_rank - 1].Get<std::int64_t>();
symbol::DimExpr n_frames = 1 + (seq_length - n_fft) / hop_length;

PADDLE_ENFORCE_LE(n_fft,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放在上面的if分支里


if (x_shape[x_rank - 1].isa<int64_t>()) {
int seq_length = x_shape[x_rank - 1].Get<std::int64_t>();
int n_frames = 1 + (seq_length - n_fft) / hop_length;
Copy link
Contributor

@gongshaotian gongshaotian Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seq_length 和 n_frames放在if外面,直接声明为DimExpr类型,if判断是否是int只服务于ENFORCE_LE

Copy link
Contributor

@gongshaotian gongshaotian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

} else {
output_shape.push_back(symbol::DimExpr{n_fft});
}
output_shape.push_back(symbol::DimExpr{n_frames});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面声明用的是DimExpr的话,这里就不用转换了

@Whsjrczr Whsjrczr closed this Sep 18, 2024
@gongshaotian gongshaotian reopened this Sep 18, 2024
infer_context->AddEqualCstr(window_shape[0], symbol::DimExpr{n_fft});
const symbol::DimExpr seq_length = x_shape[x_rank - 1];
const symbol::DimExpr n_frames =
(symbol::DimExpr{1}) +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里多了个括号

Copy link
Contributor

@gongshaotian gongshaotian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 merged commit 281e384 into PaddlePaddle:develop Sep 18, 2024
28 of 30 checks passed
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
contributor External developers HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants