-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
【Infer Symbolic Shape No.230】【BUAA】Add stft, changed 3 files #67663
Conversation
|
||
infer_context->AddEqualCstr(window_shape[0], symbol::DimExpr{n_fft}); | ||
|
||
int seq_length = x_shape[x_rank - 1].Get<std::int64_t>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里加个判断吧,是int类型再get
int seq_length = x_shape[x_rank - 1].Get<std::int64_t>(); | ||
symbol::DimExpr n_frames = 1 + (seq_length - n_fft) / hop_length; | ||
|
||
PADDLE_ENFORCE_LE(n_fft, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
放在上面的if分支里
|
||
if (x_shape[x_rank - 1].isa<int64_t>()) { | ||
int seq_length = x_shape[x_rank - 1].Get<std::int64_t>(); | ||
int n_frames = 1 + (seq_length - n_fft) / hop_length; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seq_length 和 n_frames放在if外面,直接声明为DimExpr类型,if判断是否是int只服务于ENFORCE_LE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
} else { | ||
output_shape.push_back(symbol::DimExpr{n_fft}); | ||
} | ||
output_shape.push_back(symbol::DimExpr{n_frames}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面声明用的是DimExpr的话,这里就不用转换了
infer_context->AddEqualCstr(window_shape[0], symbol::DimExpr{n_fft}); | ||
const symbol::DimExpr seq_length = x_shape[x_rank - 1]; | ||
const symbol::DimExpr n_frames = | ||
(symbol::DimExpr{1}) + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里多了个括号
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
CINN
PR Types
Others
Description
加入stft,单测位于
test\legacy_test\test_stft_op.py
check_dygraph=False
,默认的check_pir=False, check_symbol_infer=True