[Complex op] add complex support for sin, cos, tan, tanh #55380

Merged
40 changes: 40 additions & 0 deletions paddle/phi/common/complex.h
@@ -422,6 +422,36 @@ HOSTDEVICE inline complex<T> sqrt(const complex<T>& a) {
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> sin(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::sin(thrust::complex<T>(a)));
#else
return complex<T>(std::sin(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> cos(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::cos(thrust::complex<T>(a)));
#else
return complex<T>(std::cos(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> tan(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::tan(thrust::complex<T>(a)));
#else
return complex<T>(std::tan(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
  return complex<T>(thrust::tanh(thrust::complex<T>(a)));
#else
  return complex<T>(std::tanh(std::complex<T>(a)));
#endif
}

template <typename T>
HOSTDEVICE inline complex<T> conj(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::conj(thrust::complex<T>(a)));
#else
return complex<T>(std::conj(std::complex<T>(a)));
#endif
}
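
All of these wrappers follow the same dispatch rule: on CUDA/HIP device compilation passes they forward to the thrust:: implementations, otherwise to std::. As a quick sanity sketch of what the sin wrapper computes (standalone host code using std::complex, not phi::dtype::complex), for z = x + iy, sin(z) = sin(x)cosh(y) + i cos(x)sinh(y):

// Standalone host-side sketch; std::sin and thrust::sin both satisfy this
// identity, which is what the wrapper above ultimately returns.
#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const double x = 0.5, y = 1.25;
  const std::complex<double> z(x, y);
  const std::complex<double> lib = std::sin(z);
  const std::complex<double> manual(std::sin(x) * std::cosh(y),
                                    std::cos(x) * std::sinh(y));
  std::printf("std::sin : (%f, %f)\n", lib.real(), lib.imag());
  std::printf("identity : (%f, %f)\n", manual.real(), manual.imag());
}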

template <typename T>
HOSTDEVICE inline complex<T> log(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
53 changes: 42 additions & 11 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -255,13 +255,34 @@ PD_REGISTER_KERNEL(
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}

#define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}

#define PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, CPU, ALL_LAYOUT, phi::func, float, double, phi::dtype::float16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tan_grad, TanGradKernel)
#define PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tan_grad, TanGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acos_grad, AcosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asin_grad, AsinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atan_grad, AtanGradKernel)
@@ -270,7 +291,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(cosh_grad, CoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
                                    ThresholdedReluGradKernel)

@@ -290,8 +311,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)

PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
ReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(tanh_double_grad,
TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
Expand All @@ -308,7 +329,9 @@ PD_REGISTER_KERNEL(tanh_triple_grad,
phi::TanhTripleGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(exp_grad,
CPU,
@@ -355,7 +378,9 @@
PD_REGISTER_KERNEL(sin_double_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::SinDoubleGradKernel,
                   float,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(sin_triple_grad,
CPU,
                   ALL_LAYOUT,
                   phi::SinTripleGradKernel,
                   float,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(cos_double_grad,
CPU,
                   ALL_LAYOUT,
                   phi::CosDoubleGradKernel,
                   float,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(cos_triple_grad,
CPU,
                   ALL_LAYOUT,
                   phi::CosTripleGradKernel,
                   float,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
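
For reference, the *_WITH_COMPLEX registration macros introduced above differ from the plain ones only in the two trailing complex dtypes, so a registration such as PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel) expands to roughly:

// Illustrative expansion, derived from the macro body above.
PD_REGISTER_KERNEL(sin_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::SinGradKernel,
                   float,
                   double,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}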
18 changes: 14 additions & 4 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -166,9 +166,19 @@ PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}

PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel)
#define PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tan, TanKernel)
PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel)
PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel)
PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel)
@@ -177,7 +187,7 @@ PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
137 changes: 137 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -96,6 +96,17 @@ struct Cosine<dtype::bfloat16> {
}
};

template <typename T>
using ComplexType = phi::dtype::complex<T>;

// T is float or double; Conj<T> maps a phi::dtype::complex<T> value to its
// complex conjugate.
template <typename T>
struct Conj {
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
return ComplexType<T>(val.real, -val.imag);
}
};
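
This helper exists because the complex backward functors below follow the Wirtinger-calculus convention common in complex autograd: for a holomorphic f and a real-valued loss, the propagated gradient is dx = dout * conj(f'(x)) rather than dout * f'(x). A standalone numerical sketch (hypothetical names, std::complex) checking this for sin:

// NOT Paddle code; verifies that dx = dout * conj(cos(z)) matches the
// (dL/dx, dL/dy) gradient of the real-valued loss L(z) = real(sin(z)).
#include <complex>
#include <cstdio>

int main() {
  using C = std::complex<double>;
  const C z(0.3, -0.7);
  const double h = 1e-6;

  // Central differences along the real and imaginary axes.
  auto L = [](const C& w) { return std::real(std::sin(w)); };
  const double dLdx = (L(z + C(h, 0)) - L(z - C(h, 0))) / (2 * h);
  const double dLdy = (L(z + C(0, h)) - L(z - C(0, h))) / (2 * h);
  const C numeric(dLdx, dLdy);

  // With upstream gradient dout = 1, the functors compute conj(sin'(z)).
  const C analytic = std::conj(std::cos(z));

  std::printf("numeric : (%.8f, %.8f)\n", numeric.real(), numeric.imag());
  std::printf("analytic: (%.8f, %.8f)\n", analytic.real(), analytic.imag());
}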

// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) = dout * x.unaryExpr(Cosine<T>());
  }

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct SinGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) =
dout * x.unaryExpr(Cosine<ComplexType<T>>()).unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Sine<T>());
  }
};

@@ -320,6 +346,22 @@ struct CosGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CosGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) =
-dout * x.unaryExpr(Sine<ComplexType<T>>()).unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// cos''(x) = -cos(x)
template <typename T>
struct CosDoubleGradFunctor : public BaseActivationFunctor<T> {
@@ -584,6 +626,27 @@ struct TanGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct TanGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
// auto dx_ =
// static_cast<ComplexType<T>>(x.unaryExpr(Cosine<T>()).square());
// ComplexType<T> dx_conj_(dx_.real, -dx_.imag);
// dx.device(d) = dout / dx_conj_;
dx.device(d) =
dout /
x.unaryExpr(Cosine<ComplexType<T>>()).square().unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
@@ -1217,6 +1280,28 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct TanhGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
// auto dx_ = static_cast<ComplexType<T>>(1) - out * out;
// ComplexType<T> dx_conj_(dx_.real, -dx_.imag);
// dx.device(d) = dout * dx_conj_;
dx.device(d) =
dout *
(static_cast<ComplexType<T>>(1) - out * out).unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
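
The identity tanh'(z) = 1 - tanh(z)^2 holds for complex arguments as well, which is why this functor needs only out and declares kDepOut. A quick standalone check against the equivalent form 1/cosh(z)^2:

// Standalone sketch (std::complex): both expressions below agree.
#include <complex>
#include <cstdio>

int main() {
  const std::complex<double> z(0.4, 0.9);
  const std::complex<double> t = std::tanh(z);
  const std::complex<double> a = 1.0 - t * t;
  const std::complex<double> b = 1.0 / (std::cosh(z) * std::cosh(z));
  std::printf("1 - tanh^2(z): (%f, %f)\n", a.real(), a.imag());
  std::printf("1 / cosh^2(z): (%f, %f)\n", b.real(), b.imag());
}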

template <typename T>
struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
@@ -2675,6 +2760,18 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaCosGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
  // dx = -dout * conj(sin(x))
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(-dout * conj(sin(x)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
// exp(x) = expf(x)
@@ -2847,6 +2944,18 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSinGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
  // dx = dout * conj(cos(x))
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(dout * conj(cos(x)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -2873,6 +2982,18 @@ struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaTanGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
  // dx = dout / conj(cos(x)^2)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(dout / conj(cos(x) * cos(x)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -3250,6 +3371,22 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaTanhGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);

  // dx = dout * conj(1 - out^2)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> out) const {
return dout * conj(one - out * out);
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
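
These CUDA functors are plain elementwise callables that Paddle applies through its own elementwise launch machinery. A hypothetical standalone kernel applying the same per-element rule (illustrative only, not how Paddle launches it):

// HYPOTHETICAL sketch: hand-rolled CUDA kernel for dx = dout * conj(1 - out^2).
// Launch example: tanh_grad_complex<<<(n + 255) / 256, 256>>>(dout, out, dx, n);
#include <thrust/complex.h>

template <typename T>
__global__ void tanh_grad_complex(const thrust::complex<T>* dout,
                                  const thrust::complex<T>* out,
                                  thrust::complex<T>* dx,
                                  int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    const thrust::complex<T> one(1, 0);
    dx[i] = dout[i] * thrust::conj(one - out[i] * out[i]);
  }
}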

template <typename T>
struct CudaHardTanhFunctor : public BaseActivationFunctor<T> {
float t_min;