diff --git a/paddle/phi/kernels/cpu/cross_grad_kernel.cc b/paddle/phi/kernels/cpu/cross_grad_kernel.cc
index 882c3dd9ee5120..4c41107ba0199e 100644
--- a/paddle/phi/kernels/cpu/cross_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/cross_grad_kernel.cc
@@ -18,6 +18,8 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 
 namespace phi {
 
@@ -81,9 +83,27 @@ void CrossGradKernel(const Context &dev_ctx,
     slice_size *= static_cast<int>(input_x_dims[i]);
   }
 
+  int64_t numel = x.numel();
+  DenseTensor x_conj, y_conj;
+  DenseTensorMeta meta_xy(x.dtype(), x.dims());
+  x_conj.set_meta(meta_xy);
+  y_conj.set_meta(meta_xy);
+
+  auto *input_x_conj_data = dev_ctx.template Alloc<T>(&x_conj);
+
+  auto *input_y_conj_data = dev_ctx.template Alloc<T>(&y_conj);
+
+  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+  phi::funcs::ConjFunctor<T> functor_x(
+      input_x.data<T>(), numel, input_x_conj_data);
+  phi::funcs::ConjFunctor<T> functor_y(
+      input_y.data<T>(), numel, input_y_conj_data);
+  for_range(functor_x);
+  for_range(functor_y);
+
   std::vector<T> input_x_vec, input_y_vec, input_dout_vec;
-  phi::TensorToVector(input_x, dev_ctx, &input_x_vec);
-  phi::TensorToVector(input_y, dev_ctx, &input_y_vec);
+  phi::TensorToVector(x_conj, dev_ctx, &input_x_vec);
+  phi::TensorToVector(y_conj, dev_ctx, &input_y_vec);
   phi::TensorToVector(input_out_grad, dev_ctx, &input_dout_vec);
   std::vector<T> out_dx_vec(output_x_grad->numel());
   std::vector<T> out_dy_vec(output_y_grad->numel());
@@ -120,4 +140,6 @@ PD_REGISTER_KERNEL(cross_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/cross_kernel.cc b/paddle/phi/kernels/cpu/cross_kernel.cc
index 0f45b7c304e319..95f826cfe9132b 100644
--- a/paddle/phi/kernels/cpu/cross_kernel.cc
+++ b/paddle/phi/kernels/cpu/cross_kernel.cc
@@ -105,5 +105,13 @@ void CrossKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    cross, CPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(cross,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::CrossKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/cross_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_grad_kernel.cu
index 58f53fcf3f3d22..33a3ce4e12f3ea 100644
--- a/paddle/phi/kernels/gpu/cross_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_grad_kernel.cu
@@ -18,6 +18,8 @@
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/index_calculator.h"
 
 namespace phi {
@@ -162,27 +164,56 @@ void CrossGradKernel(const Context& dev_ctx,
   const auto* input_x_data = input_x.data<T>();
   const auto* input_y_data = input_y.data<T>();
+  int64_t numel = x.numel();
   const auto* input_out_grad_data = input_out_grad.data<T>();
   auto* output_x_grad_data = dev_ctx.template Alloc<T>(x_grad);
   auto* output_y_grad_data = dev_ctx.template Alloc<T>(y_grad);
 
   auto index_calculator = phi::funcs::IndexCalculator(
       merged_dims.size() - 1, cal_dims, left_strides, full_strides);
 
-  int64_t numel = x.numel();
   backends::gpu::GpuLaunchConfig config =
       backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel / 3);
-
-  CrossGrad<<<config.block_per_grid,
-              config.thread_per_block,
-              0,
-              dev_ctx.stream()>>>(input_x_data,
-                                  input_y_data,
-                                  input_out_grad_data,
-                                  output_x_grad_data,
-                                  output_y_grad_data,
-                                  full_strides[merge_axis],
-                                  numel / 3,
-                                  index_calculator);
+  if (IsComplexType(x.dtype())) {
+    DenseTensor x_conj, y_conj;
+    DenseTensorMeta meta_xy(x.dtype(), x.dims());
+    x_conj.set_meta(meta_xy);
+    y_conj.set_meta(meta_xy);
+
+    auto* input_x_conj_data = dev_ctx.template Alloc<T>(&x_conj);
+    auto* input_y_conj_data = dev_ctx.template Alloc<T>(&y_conj);
+
+    phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+    phi::funcs::ConjFunctor<T> functor_x(
+        input_x_data, numel, input_x_conj_data);
+    phi::funcs::ConjFunctor<T> functor_y(
+        input_y_data, numel, input_y_conj_data);
+    for_range(functor_x);
+    for_range(functor_y);
+
+    CrossGrad<<<config.block_per_grid,
+                config.thread_per_block,
+                0,
+                dev_ctx.stream()>>>(input_x_conj_data,
+                                    input_y_conj_data,
+                                    input_out_grad_data,
+                                    output_x_grad_data,
+                                    output_y_grad_data,
+                                    full_strides[merge_axis],
+                                    numel / 3,
+                                    index_calculator);
+  } else {
+    CrossGrad<<<config.block_per_grid,
+                config.thread_per_block,
+                0,
+                dev_ctx.stream()>>>(input_x_data,
+                                    input_y_data,
+                                    input_out_grad_data,
+                                    output_x_grad_data,
+                                    output_y_grad_data,
+                                    full_strides[merge_axis],
+                                    numel / 3,
+                                    index_calculator);
+  }
 }
 
 }  // namespace phi
@@ -195,4 +226,6 @@ PD_REGISTER_KERNEL(cross_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/cross_kernel.cu b/paddle/phi/kernels/gpu/cross_kernel.cu
index 461e3a219d5d6a..f1671c67973f51 100644
--- a/paddle/phi/kernels/gpu/cross_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_kernel.cu
@@ -172,4 +172,6 @@ PD_REGISTER_KERNEL(cross,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 5ff36cdb754d53..36f55559199bb7 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1904,8 +1904,8 @@ def cross(x, y, axis=9, name=None):
     If `axis` is not given, it defaults to the first axis found with the length 3.
 
     Args:
-        x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64.
-        y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64.
+        x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64, complex64, complex128.
+        y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64, complex64, complex128.
         axis (int, optional): The axis along which to compute the cross product. It defaults to be 9 which indicates using the first axis found with the length 3.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
@@ -1945,13 +1945,31 @@ def cross(x, y, axis=9, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                "int32",
+                "int64",
+                "complex64",
+                "complex128",
+            ],
             'cross',
         )
         check_variable_and_dtype(
             y,
             'y',
-            ['float16', 'uint16', 'float32', 'float64', "int32", "int64"],
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                "int32",
+                "int64",
+                "complex64",
+                "complex128",
+            ],
             'cross',
         )
         helper = LayerHelper("cross", **locals())
diff --git a/test/legacy_test/test_cross_op.py b/test/legacy_test/test_cross_op.py
index 803a6924f25f9d..5f35f5099d5d44 100644
--- a/test/legacy_test/test_cross_op.py
+++ b/test/legacy_test/test_cross_op.py
@@ -32,6 +32,17 @@ def setUp(self):
             'X': np.random.random(self.shape).astype(self.dtype),
             'Y': np.random.random(self.shape).astype(self.dtype),
         }
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            self.inputs = {
+                'X': (
+                    np.random.random(self.shape)
+                    + 1j * np.random.random(self.shape)
+                ).astype(self.dtype),
+                'Y': (
+                    np.random.random(self.shape)
+                    + 1j * np.random.random(self.shape)
+                ).astype(self.dtype),
+            }
         self.init_output()
 
     def initTestCase(self):
@@ -81,6 +92,30 @@ def init_output(self):
         self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
 
 
+class TestCrossComplex64Op(TestCrossOp):
+    def initTestCase(self):
+        self.shape = (2048, 3)
+        self.dtype = np.complex64
+
+    def init_output(self):
+        z_list = []
+        for i in range(2048):
+            z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
+        self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
+
+
+class TestCrossComplex128Op(TestCrossOp):
+    def initTestCase(self):
+        self.shape = (2048, 3)
+        self.dtype = np.complex128
+
+    def init_output(self):
+        z_list = []
+        for i in range(2048):
+            z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
+        self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
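For reviewers, a minimal end-to-end sketch of what the new registrations enable. It is not part of the patch; the shapes, tolerance, and the abs-based scalar loss are illustrative assumptions.

```python
import numpy as np
import paddle

# Forward: complex64 inputs are now accepted and should match np.cross.
x_np = (np.random.random((4, 3)) + 1j * np.random.random((4, 3))).astype(np.complex64)
y_np = (np.random.random((4, 3)) + 1j * np.random.random((4, 3))).astype(np.complex64)

x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.to_tensor(y_np, stop_gradient=False)

out = paddle.cross(x, y, axis=1)
np.testing.assert_allclose(out.numpy(), np.cross(x_np, y_np), rtol=1e-5)

# Backward: drive the new complex cross_grad kernels through a real-valued
# scalar loss (assumption: paddle.abs is differentiable for complex inputs).
loss = paddle.abs(out).sum()
loss.backward()
assert x.grad.dtype == paddle.complex64  # gradients keep the input dtype
```

The backward path is where the ConjFunctor calls in the kernels above matter: for complex dtypes the inputs are conjugated before the existing cross-gradient formula is applied, while the real and integer dtypes keep the original code path.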