[Complex op] add complex support for sin, cos, tan, tanh #55380

Merged

Changes from 4 commits
30 changes: 30 additions & 0 deletions paddle/phi/common/complex.h
@@ -422,6 +422,36 @@ HOSTDEVICE inline complex<T> sqrt(const complex<T>& a) {
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> sin(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::sin(thrust::complex<T>(a)));
#else
return complex<T>(std::sin(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> cos(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::cos(thrust::complex<T>(a)));
#else
return complex<T>(std::cos(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> tan(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::tan(thrust::complex<T>(a)));
#else
return complex<T>(std::tan(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
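For reference, the std:: and thrust:: overloads delegated to above compute the standard complex trigonometric identities, e.g. sin(a+bi) = sin(a)cosh(b) + i*cos(a)sinh(b). A small NumPy sketch (not part of this diff) that checks those identities numerically:

import numpy as np

z = np.array([0.3 + 0.7j, -1.2 + 0.1j], dtype=np.complex64)
a, b = z.real, z.imag

# sin(a+bi) = sin(a)cosh(b) + i*cos(a)sinh(b)
assert np.allclose(np.sin(z), np.sin(a) * np.cosh(b) + 1j * np.cos(a) * np.sinh(b), atol=1e-5)
# cos(a+bi) = cos(a)cosh(b) - i*sin(a)sinh(b)
assert np.allclose(np.cos(z), np.cos(a) * np.cosh(b) - 1j * np.sin(a) * np.sinh(b), atol=1e-5)
# tan and tanh reduce to ratios of the corresponding (hyperbolic) sine and cosine
assert np.allclose(np.tan(z), np.sin(z) / np.cos(z), atol=1e-5)
assert np.allclose(np.tanh(z), np.sinh(z) / np.cosh(z), atol=1e-5)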
32 changes: 25 additions & 7 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -255,13 +255,23 @@ PD_REGISTER_KERNEL(
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}

#define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}

#define PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, CPU, ALL_LAYOUT, phi::func, float, double, phi::dtype::float16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tan_grad, TanGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tan_grad, TanGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acos_grad, AcosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asin_grad, AsinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atan_grad, AtanGradKernel)
@@ -355,7 +365,9 @@ PD_REGISTER_KERNEL(sin_double_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(sin_triple_grad,
CPU,
@@ -365,7 +377,9 @@ PD_REGISTER_KERNEL(sin_triple_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(cos_double_grad,
CPU,
@@ -375,7 +389,9 @@ PD_REGISTER_KERNEL(cos_double_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(cos_triple_grad,
CPU,
@@ -385,7 +401,9 @@ PD_REGISTER_KERNEL(cos_triple_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
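A quick mathematical sanity check on what the newly registered complex grad kernels should produce: sin, cos, tan and tanh are holomorphic, so their complex derivatives are the usual cos(z), -sin(z), 1/cos(z)^2 and 1 - tanh(z)^2. A NumPy finite-difference sketch (not part of this diff; whether the framework additionally conjugates the gradient for complex tensors is decided by the kernel functors, which are outside this hunk):

import numpy as np

z = np.complex128(0.4 + 0.9j)
h = 1e-6  # small real step for the finite difference

# d/dz sin(z) = cos(z)
assert np.allclose((np.sin(z + h) - np.sin(z)) / h, np.cos(z), atol=1e-5)
# d/dz tan(z) = 1 / cos(z)**2
assert np.allclose((np.tan(z + h) - np.tan(z)) / h, 1 / np.cos(z) ** 2, atol=1e-5)
# d/dz tanh(z) = 1 - tanh(z)**2
assert np.allclose((np.tanh(z + h) - np.tanh(z)) / h, 1 - np.tanh(z) ** 2, atol=1e-5)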
18 changes: 14 additions & 4 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -156,9 +156,19 @@ PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}

PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel)
#define PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tan, TanKernel)
PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel)
PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel)
PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel)
@@ -167,7 +177,7 @@ PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
36 changes: 28 additions & 8 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -343,9 +343,21 @@ PD_REGISTER_KERNEL(relu_double_grad,
phi::dtype::float16, \
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tan_grad, TanGradKernel)
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
GPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tan_grad, TanGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acos_grad, AcosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asin_grad, AsinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atan_grad, AtanGradKernel)
@@ -354,7 +366,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(cosh_grad, CoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_double_grad, TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_triple_grad, TanhTripleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
@@ -433,7 +445,9 @@ PD_REGISTER_KERNEL(sin_double_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(sin_triple_grad,
GPU,
@@ -444,7 +458,9 @@ PD_REGISTER_KERNEL(sin_triple_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(cos_double_grad,
GPU,
@@ -455,7 +471,9 @@ PD_REGISTER_KERNEL(cos_double_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(cos_triple_grad,
GPU,
@@ -466,7 +484,9 @@ PD_REGISTER_KERNEL(cos_triple_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
20 changes: 16 additions & 4 deletions paddle/phi/kernels/gpu/activation_kernel.cu
@@ -207,9 +207,21 @@ PD_REGISTER_KERNEL(relu,
phi::dtype::float16, \
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel)
#define PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
GPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tan, TanKernel)
PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel)
PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel)
PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel)
@@ -218,7 +230,7 @@ PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
29 changes: 26 additions & 3 deletions python/paddle/tensor/ops.py
@@ -482,7 +482,10 @@ def cos(x, name=None):
return _C_ops.cos(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'cos'
x,
'x',
['float16', 'float32', 'float64', 'complex64', 'complex128'],
'cos',
)
helper = LayerHelper('cos', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -849,7 +852,17 @@ def sin(x, name=None):
return _C_ops.sin(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'sin'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'sin',
)
helper = LayerHelper('sin', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -1011,7 +1024,17 @@ def tan(x, name=None):
return _C_ops.tan(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'tan'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'tan',
)
helper = LayerHelper('tan', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
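With the dtype checks above extended to complex64 and complex128, the Python API accepts complex inputs. A minimal dynamic-graph usage sketch (assuming a Paddle build that includes this change; the input values are arbitrary):

import numpy as np
import paddle

x = paddle.to_tensor(np.array([0.3 + 0.7j, -1.2 + 0.1j], dtype=np.complex64))

# the activations extended to complex dtypes in this PR
y_sin = paddle.sin(x)
y_cos = paddle.cos(x)
y_tan = paddle.tan(x)
y_tanh = paddle.tanh(x)
print(y_sin, y_cos, y_tan, y_tanh)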
46 changes: 44 additions & 2 deletions test/legacy_test/test_activation_op.py
@@ -566,6 +566,16 @@ def if_enable_cinn(self):
pass


class TestTanh_Complex64(TestTanh):
def init_dtype(self):
[Review comment from a Contributor] This only changes the dtype; it looks like the test input has only a real part and no imaginary part (the imaginary parts are all zero).

self.dtype = np.complex64


class TestTanh_Complex128(TestTanh):
def init_dtype(self):
self.dtype = np.complex128

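Following up on the review comment above: these subclasses only switch the dtype, so the randomly generated input still has zero imaginary parts. A sketch (not part of this diff; make_complex_input and its value ranges are hypothetical) of how a test could build an input with nonzero imaginary parts:

import numpy as np

def make_complex_input(shape, dtype=np.complex64, seed=1024):
    # hypothetical helper: real and imaginary parts both drawn at random
    rng = np.random.default_rng(seed)
    real = rng.uniform(0.1, 1.0, shape)
    imag = rng.uniform(0.1, 1.0, shape)
    return (real + 1j * imag).astype(dtype)

x = make_complex_input([11, 17])
out = np.tanh(x)  # reference values the test would compare against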

class TestTanh_ZeroDim(TestTanh):
def init_shape(self):
self.shape = []
@@ -1577,12 +1587,23 @@ def init_shape(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True)
# TODO(ScottWong98): set `check_prim=True` once `fill_any_like` supports `complex` dtype
self.check_grad(['X'], 'Out', check_prim=False)

def if_enable_cinn(self):
pass


class TestCos_Complex64(TestCos):
def init_dtype(self):
self.dtype = np.complex64


class TestCos_Complex128(TestCos):
def init_dtype(self):
self.dtype = np.complex128


class TestCos_ZeroDim(TestCos):
def init_shape(self):
self.shape = []
@@ -1619,6 +1640,16 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out')


class TestTan_Complex64(TestTan):
def init_dtype(self):
self.dtype = np.complex64


class TestTan_Complex128(TestTan):
def init_dtype(self):
self.dtype = np.complex128


class TestTan_ZeroDim(TestTan):
def init_shape(self):
self.shape = []
@@ -1718,12 +1749,23 @@ def init_shape(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True)
# TODO(ScottWong98): set `check_prim=True` once `fill_any_like` supports `complex` dtype
self.check_grad(['X'], 'Out', check_prim=False)

def if_enable_cinn(self):
pass


class TestSin_Complex64(TestSin):
def init_dtype(self):
self.dtype = np.complex64


class TestSin_Complex128(TestSin):
def init_dtype(self):
self.dtype = np.complex128


class TestSin_ZeroDim(TestSin):
def init_shape(self):
self.shape = []