Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

【complex op No.7】add complex support for Log/log10/log2/log1p #62448

Merged
merged 7 commits into from
Mar 29, 2024

Conversation

zbt78
Copy link
Contributor

@zbt78 zbt78 commented Mar 6, 2024

PR types

New features

PR changes

OPs

Description

add complex support for Log/log10/log2/log1p

#61975

Copy link

paddle-bot bot commented Mar 6, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -32,6 +33,35 @@ namespace phi {
dev_ctx, &x, nullptr, &dout, dx, functor); \
}

#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_WITH_INT_IN_FLOAT_OUT_DEPX( \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么要新增这个宏呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为了支持输入数据为int,输出数据为float时的 backward。求梯度时由于x的类型和T不一致,在把x转化为EigenVector时会出错,所以这里先转化了x。
另外这是四个相关的api,只有functor的名字不同,不使用宏的话占的篇幅比较大。
有个问题是对于int型输入需不需要求backward,torch应该是不支持的
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些kernel原来就不接收int的输入吧,你看原来的注册宏里面就没有注册int相关的类型

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#54089 是这个里面新加的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的是反向kernel,因为如果需要算梯度的话,算出来的x梯度是float类型,还得再转回int,从float转成int 损失太大,这个梯度就没有意义了,所以这种的反向一般都不支持int的输入。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,那我把backward相关部分删除了

x.dtype() == phi::DataType::INT64) { \
DenseTensor x_fp; \
MetaTensor meta_xp(x_fp); \
UnchangedInferMeta(dx, &meta_xp); \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

infermeta 按理来说不应该放在kernel 内部

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

check_pir=True, check_pir_onednn=self.check_pir_onednn
)

def test_api_complex(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不能复用check_output吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该是可以的

PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log1p_grad, Log1pGradKernel)
PD_REGISTER_KERNEL(log_double_grad,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里注册了log的double_grad,但是单测有覆盖到这个二阶吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还没覆盖,高阶导的测试有示例吗🤔,我没太找到相应的实现参考

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前高阶单测不支持复数,这个就先不加了,加了但是不能保证正确性意义不大

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该也是有之前的问题,高阶如果不注册的话,低阶在测试的时候会出问题。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个还是加个单测吧,可以用np手写反向计算逻辑,然后与paddle 算出来的进行比较,不能添加一个不能保证正确性额kernel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if device == 'cpu' or (
device == 'gpu' and paddle.is_compiled_with_cuda()
):
paddle.set_device(device)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用 paddle.set_device(device) 这个会不会引起全局并发的问题,从而影响其他的单测,可以在to_tensor的时候设置place来区分不同设备

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

@GGBond8488 GGBond8488 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1 luotao1 merged commit ed19f42 into PaddlePaddle:develop Mar 29, 2024
30 checks passed
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants