【complex op No.7】add complex support for Log/log10/log2/log1p #62448

Merged: 7 commits, Mar 29, 2024
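For orientation before the diff: the new complex specializations in this PR all reduce to standard complex-logarithm identities. The snippet below is a standalone sketch (not part of the change set) of the elementwise forward formulas they compute, written against std::complex; the `complex_log*` function names are illustrative only and do not exist in Paddle.

```cpp
// Standalone sketch of the forward formulas the new complex specializations
// compute; mirrors Log/Log2/Log10/Log1p<ComplexType<T>> in the diff below.
#include <complex>
#include <iostream>

template <typename T>
std::complex<T> complex_log(const std::complex<T>& z) {
  return std::log(z);  // natural logarithm, as in Log<ComplexType<T>>
}

template <typename T>
std::complex<T> complex_log2(const std::complex<T>& z) {
  // log2(z) = log(z) / log(2), as in Log2<ComplexType<T>>
  return std::log(z) / std::log(std::complex<T>(2));
}

template <typename T>
std::complex<T> complex_log10(const std::complex<T>& z) {
  return std::log10(z);  // as in Log10<ComplexType<T>>
}

template <typename T>
std::complex<T> complex_log1p(const std::complex<T>& z) {
  // log1p(z) = log(1 + z), as in Log1p<ComplexType<T>>
  return std::log(std::complex<T>(1) + z);
}

int main() {
  const std::complex<double> z(0.5, -1.25);
  std::cout << complex_log(z) << " " << complex_log2(z) << " "
            << complex_log10(z) << " " << complex_log1p(z) << "\n";
  return 0;
}
```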
11 changes: 6 additions & 5 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -438,11 +438,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad,
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
LogSigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log1p_grad, Log1pGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(log_double_grad,
LogDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
16 changes: 12 additions & 4 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -254,7 +254,9 @@ PD_REGISTER_KERNEL(log,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(log2,
CPU,
ALL_LAYOUT,
@@ -264,7 +266,9 @@ PD_REGISTER_KERNEL(log2,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(log10,
CPU,
ALL_LAYOUT,
@@ -274,7 +278,9 @@ PD_REGISTER_KERNEL(log10,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(log1p,
CPU,
ALL_LAYOUT,
@@ -284,7 +290,9 @@ PD_REGISTER_KERNEL(log1p,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
225 changes: 225 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2445,6 +2445,13 @@ struct Log {
HOSTDEVICE T operator()(const T& val) const { return std::log(val); }
};

template <typename T>
struct Log<ComplexType<T>> {
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
return ComplexType<T>(std::log(std::complex<T>(val)));
}
};

template <>
struct Log<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2484,11 +2491,35 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct LogGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) =
dout * (static_cast<ComplexType<T>>(1) / x).unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Log2 {
HOSTDEVICE T operator()(const T& val) const { return std::log2(val); }
};

template <typename T>
struct Log2<ComplexType<T>> {
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
return ComplexType<T>(std::log(std::complex<T>(val)) /
std::log(std::complex<T>(2)));
}
};

template <>
struct Log2<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2529,11 +2560,35 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Log2GradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
(x * static_cast<ComplexType<T>>(log(2))))
.unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Log10 {
HOSTDEVICE T operator()(const T& val) const { return std::log10(val); }
};

template <typename T>
struct Log10<ComplexType<T>> {
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
return ComplexType<T>(std::log10(std::complex<T>(val)));
}
};

template <>
struct Log10<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2574,11 +2629,35 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Log10GradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
(x * static_cast<ComplexType<T>>(log(10))))
.unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Log1p {
HOSTDEVICE T operator()(const T& val) const { return std::log1p(val); }
};

template <typename T>
struct Log1p<ComplexType<T>> {
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
return ComplexType<T>(std::log(std::complex<T>(1) + std::complex<T>(val)));
}
};

template <>
struct Log1p<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2618,6 +2697,23 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct Log1pGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
(x + static_cast<ComplexType<T>>(1)))
.unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
@@ -2651,6 +2747,42 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct LogGradGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* ddX,
DenseTensor* ddOut,
const DenseTensor* dOut,
DenseTensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
auto x = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
// ddout = ddx / x; dx = -(dout / x) * (ddx / x)
// calculate dx first, so ddout can inplace ddx
if (dX) {
auto dout = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
auto dx = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
dx.device(*d) = dout * static_cast<ComplexType<T>>(-1) * ddx /
(x * x).unaryExpr(Conj<T>());
}
if (ddOut) {
auto ddout = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
ddout.device(*d) =
ddx * static_cast<ComplexType<T>>(1) / x.unaryExpr(Conj<T>());
}
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// HardSwish = min(max(0, x+3), 6) * x / 6
template <typename T>
struct HardSwishFunctor : public BaseActivationFunctor<T> {
@@ -4642,6 +4774,16 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaLogFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// log(x) = log(x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_x) const {
return static_cast<ComplexType<T>>(log(arg_x));
}
};

template <typename T>
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
// dx = dout / x
@@ -4652,6 +4794,18 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLogGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = dout / conj(x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return dout / conj(x);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4665,6 +4819,17 @@ struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaLog1pFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// log1p(x) = log(1 + x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_x) const {
return static_cast<ComplexType<T>>(
log(static_cast<ComplexType<T>>(1) + arg_x));
}
};

template <typename T>
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
@@ -4677,6 +4842,20 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog1pGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);

// dx = dout / conj(1 + x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return dout / conj(one + x);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
@@ -4709,6 +4888,17 @@ struct CudaLog2Functor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaLog2Functor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// log2(x) = log(x)/log(2)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_x) const {
return static_cast<ComplexType<T>>(log(arg_x) /
static_cast<ComplexType<T>>(log(2.0f)));
}
};

template <typename T>
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4722,6 +4912,18 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog2GradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = dout / conj(x * log(2))
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return dout / conj(x * static_cast<ComplexType<T>>(log(2.0f)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
__device__ __forceinline__
std::conditional_t<std::is_integral<T>::value, float, T>
Expand Down Expand Up @@ -4754,6 +4956,17 @@ struct CudaLog10Functor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaLog10Functor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// log10(x) = log(x)/log(10)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_x) const {
return static_cast<ComplexType<T>>(log(arg_x) /
static_cast<ComplexType<T>>(log(10.0f)));
}
};

template <typename T>
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4767,6 +4980,18 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaLog10GradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = dout / conj(x * log(10))
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return dout / conj(x * static_cast<ComplexType<T>>(log(10.0f)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;