diff --git a/paddle/phi/common/complex.h b/paddle/phi/common/complex.h
index 34605855137e0e..b3ad3fdc88a5cb 100644
--- a/paddle/phi/common/complex.h
+++ b/paddle/phi/common/complex.h
@@ -294,8 +294,10 @@ HOSTDEVICE inline complex<T>& operator*=(complex<T>& a,  // NOLINT
                  thrust::complex<T>(b.real, b.imag));
   return a;
 #else
-  a.real = a.real * b.real - a.imag * b.imag;
-  a.imag = a.imag * b.real + b.imag * a.real;
+  T r = a.real * b.real - a.imag * b.imag;
+  T i = a.imag * b.real + b.imag * a.real;
+  a.real = r;
+  a.imag = i;
   return a;
 #endif
 }
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 84ec899d9d399b..cb821233004f80 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -304,7 +304,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
+                                                ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
@@ -364,7 +365,9 @@ PD_REGISTER_KERNEL(square_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(square_double_grad,
                    CPU,
                    ALL_LAYOUT,
@@ -373,7 +376,9 @@ PD_REGISTER_KERNEL(square_double_grad,
                    double,
                    phi::dtype::float16,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(sin_double_grad,
                    CPU,
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index e704eefc54ebb2..11312aa3a7972b 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -198,7 +198,7 @@ PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
-PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
@@ -228,8 +228,16 @@ PD_REGISTER_KERNEL(expm1,
                    phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
-PD_REGISTER_KERNEL(
-    square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(square,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SquareKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softsign, SoftsignKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
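Note on the `complex.h` hunk above: the old `operator*=` body stored into `a.real` before reading it again to form `a.imag`, so the imaginary part was computed from the already-updated real part. A plain-Python re-enactment of both versions (the `Cpx` class and function names are illustrative stand-ins, not Paddle code):

```python
class Cpx:
    def __init__(self, real, imag):
        self.real, self.imag = real, imag

def imul_buggy(a, b):  # pre-fix body of operator*=
    a.real = a.real * b.real - a.imag * b.imag
    a.imag = a.imag * b.real + b.imag * a.real  # reads the NEW a.real
    return a

def imul_fixed(a, b):  # post-fix body: both products taken before any write
    r = a.real * b.real - a.imag * b.imag
    i = a.imag * b.real + b.imag * a.real
    a.real, a.imag = r, i
    return a

u = imul_buggy(Cpx(1, 2), Cpx(3, 4))
v = imul_fixed(Cpx(1, 2), Cpx(3, 4))
print(u.real, u.imag)        # -5 -14  -- imaginary part is wrong
print(v.real, v.imag)        # -5 10
print((1 + 2j) * (3 + 4j))   # (-5+10j) -- reference result
```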
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 06b59644cf11d4..2309f5fa30de28 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -323,6 +323,24 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct ReciprocalGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * static_cast<ComplexType<T>>(-1) *
+                   (out * out).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 // 1st reverse grad
 // y = cos(x)
 // x --> y
@@ -704,6 +722,22 @@ struct SquareGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct SquareGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    dx.device(d) =
+        dout * static_cast<ComplexType<T>>(2) * x.unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 // sqrt(x) = x^(1/2)
 template <typename T>
 struct SqrtFunctor : public BaseActivationFunctor<T> {
@@ -3220,6 +3254,20 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaSquareGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> two = static_cast<ComplexType<T>>(2.0f);
+
+  // dx = dout * 2 * conj(x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    return static_cast<ComplexType<T>>(dout * two * conj(x));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
   // dx = dout * out
@@ -3268,6 +3316,20 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaReciprocalGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // dx = -dout * conj(out^2)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> out) const {
+    return -dout * conj(out * out);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 template <typename T>
 struct CudaExpm1Functor : public BaseActivationFunctor<T> {
   using U = typename std::conditional_t<std::is_integral<T>::value, float, T>;
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 2a1c6759bbc8ba..7af857345cdd67 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -380,7 +380,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
+                                                ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
                                                 SoftplusGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
@@ -431,7 +432,9 @@ PD_REGISTER_KERNEL(square_grad,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(square_double_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -441,7 +444,9 @@ PD_REGISTER_KERNEL(square_double_grad,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_KERNEL(sin_double_grad,
                    GPU,
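All four complex specializations above conjugate the holomorphic derivative: `dx = dout * conj(f'(x))`, with `f'(x) = -out^2` for reciprocal (hence `kDepOut`) and `f'(x) = 2x` for square (hence `kDepX`). Under this convention, the one Paddle's complex autograd uses, the real and imaginary parts of `dx` are the derivatives of the real-valued loss `L = Re(conj(dout) * f(x))` with respect to `Re(x)` and `Im(x)`. A standalone NumPy sketch (not framework code) that checks both formulas against central differences:

```python
import numpy as np

x = 0.7 - 0.3j
dout = 0.2 + 0.5j
h = 1e-7

def loss(z, f):  # scalar loss whose gradient the convention encodes
    return np.real(np.conj(dout) * f(z))

cases = [
    # reciprocal: f'(x) = -1/x**2 = -out**2  ->  dx = -dout * conj(out**2)
    ("reciprocal", np.reciprocal, -dout * np.conj(np.reciprocal(x) ** 2)),
    # square: f'(x) = 2x  ->  dx = 2 * dout * conj(x)
    ("square", np.square, 2 * dout * np.conj(x)),
]
for name, f, dx in cases:
    d_re = (loss(x + h, f) - loss(x - h, f)) / (2 * h)       # dL/dRe(x)
    d_im = (loss(x + 1j * h, f) - loss(x - 1j * h, f)) / (2 * h)  # dL/dIm(x)
    print(name, dx, complex(d_re, d_im))  # the two should agree to ~1e-7
```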
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index 34bbbfbd11859e..e8dadf31fd945f 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -247,7 +247,7 @@ PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
-PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
@@ -285,7 +285,9 @@ PD_REGISTER_KERNEL(square,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 
 PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 7549b68dc336b0..c561fa7d38dd32 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -3414,6 +3414,11 @@ def setUp(self):
 
         np.random.seed(1024)
         x = np.random.uniform(1, 2, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
         out = np.reciprocal(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3423,12 +3428,29 @@ def test_check_grad(self):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out', max_relative_error=0.01, check_pir=True)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            self.check_grad(
+                ['X'], 'Out', max_relative_error=0.03, check_pir=True
+            )
+        else:
+            self.check_grad(
+                ['X'], 'Out', max_relative_error=0.01, check_pir=True
+            )
 
     def test_check_output(self):
         self.check_output(check_pir=True)
 
 
+class TestReciprocal_Complex64(TestReciprocal):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestReciprocal_Complex128(TestReciprocal):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestReciprocal_ZeroDim(TestReciprocal):
     def init_shape(self):
         self.shape = []
@@ -3799,6 +3821,11 @@ def setUp(self):
 
         np.random.seed(1024)
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
         out = np.square(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3814,6 +3841,16 @@ def test_check_output(self):
         self.check_output(check_pir=True)
 
 
+class TestSquare_Complex64(TestSquare):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestSquare_Complex128(TestSquare):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestSquare_ZeroDim(TestSquare):
     def init_shape(self):
         self.shape = []
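With these registrations in place, `paddle.reciprocal` and `paddle.square` should accept complex tensors in both forward and backward. A dygraph usage sketch (assumes a build that contains this patch):

```python
import numpy as np
import paddle

x_np = (
    np.random.uniform(-1, 1, [3]) + 1j * np.random.uniform(-1, 1, [3])
).astype(np.complex64)
x = paddle.to_tensor(x_np, stop_gradient=False)

y = paddle.square(x)
loss = paddle.real(y).sum()  # real-valued loss, so dout = 1 + 0j
loss.backward()

print(x.grad.numpy())
# SquareGradFunctor<ComplexType<T>> predicts grad = 2 * conj(x):
print(np.allclose(x.grad.numpy(), 2 * np.conj(x_np), atol=1e-5))
```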