diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index cb821233004f80..3f26f8c388e667 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -438,11 +438,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
                                                 LogSigmoidGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
-PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log_grad, LogGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log2_grad, Log2GradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log10_grad, Log10GradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log1p_grad, Log1pGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(log_double_grad,
+                                                       LogDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
                                                 HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 11312aa3a7972b..92acf104fedcf6 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -254,7 +254,9 @@ PD_REGISTER_KERNEL(log,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(log2,
                    CPU,
                    ALL_LAYOUT,
@@ -264,7 +266,9 @@ PD_REGISTER_KERNEL(log2,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(log10,
                    CPU,
                    ALL_LAYOUT,
@@ -274,7 +278,9 @@ PD_REGISTER_KERNEL(log10,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(log1p,
                    CPU,
                    ALL_LAYOUT,
@@ -284,7 +290,9 @@ PD_REGISTER_KERNEL(log1p,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 8b83fcb0d10c13..ba1d9873ec2a47 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -2445,6 +2445,13 @@ struct Log {
   HOSTDEVICE T operator()(const T& val) const { return std::log(val); }
 };
 
+template <typename T>
+struct Log<ComplexType<T>> {
+  HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
+    return ComplexType<T>(std::log(std::complex<T>(val)));
+  }
+};
+
 template <>
 struct Log<dtype::float16> {
   HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2484,11 +2491,35 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct LogGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    dx.device(d) =
+        dout * (static_cast<ComplexType<T>>(1) / x).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct Log2 {
   HOSTDEVICE T operator()(const T& val) const { return std::log2(val); }
 };
 
+template <typename T>
+struct Log2<ComplexType<T>> {
+  HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
+    return ComplexType<T>(std::log(std::complex<T>(val)) /
+                          std::log(std::complex<T>(2)));
+  }
+};
+
 template <>
 struct Log2<dtype::float16> {
   HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2529,11 +2560,35 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct Log2GradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
+                           (x * static_cast<ComplexType<T>>(log(2))))
+                              .unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct Log10 {
   HOSTDEVICE T operator()(const T& val) const { return std::log10(val); }
 };
 
+template <typename T>
+struct Log10<ComplexType<T>> {
+  HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
+    return ComplexType<T>(std::log10(std::complex<T>(val)));
+  }
+};
+
 template <>
 struct Log10<dtype::float16> {
   HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2574,11 +2629,35 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct Log10GradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
+                           (x * static_cast<ComplexType<T>>(log(10))))
+                              .unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct Log1p {
   HOSTDEVICE T operator()(const T& val) const { return std::log1p(val); }
 };
 
+template <typename T>
+struct Log1p<ComplexType<T>> {
+  HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
+    return ComplexType<T>(std::log(std::complex<T>(1) + std::complex<T>(val)));
+  }
+};
+
 template <>
 struct Log1p<dtype::float16> {
   HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2618,6 +2697,23 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct Log1pGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
+                           (x + static_cast<ComplexType<T>>(1)))
+                              .unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct LogGradGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device>
@@ -2651,6 +2747,42 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct LogGradGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* X,
+                  const DenseTensor* ddX,
+                  DenseTensor* ddOut,
+                  const DenseTensor* dOut,
+                  DenseTensor* dX) const {
+    auto* d = dev.eigen_device();
+    auto ddx = EigenVector<ComplexType<T>>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
+    auto x = EigenVector<ComplexType<T>>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
+    // ddout = ddx / x; dx = -(dout / x) * (ddx / x)
+    // calculate dx first, so ddout can inplace ddx
+    if (dX) {
+      auto dout = EigenVector<ComplexType<T>>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
+      auto dx = EigenVector<ComplexType<T>>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
+      dx.device(*d) = dout * static_cast<ComplexType<T>>(-1) * ddx /
+                      (x * x).unaryExpr(Conj<T>());
+    }
+    if (ddOut) {
+      auto ddout = EigenVector<ComplexType<T>>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
+      ddout.device(*d) =
+          ddx * static_cast<ComplexType<T>>(1) / x.unaryExpr(Conj<T>());
+    }
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 // HardSwish = min(max(0, x+3), 6) * x / 6
 template <typename T>
 struct HardSwishFunctor : public BaseActivationFunctor<T> {
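# Note (illustrative, not part of the patch): the complex grad functors above
# follow Paddle's conjugate convention, i.e. the analytic derivative is taken
# and then conjugated: dx = dout / conj(x) for log, dx = dout / conj(x * ln(2))
# for log2, dx = dout / conj(x * ln(10)) for log10, dx = dout / conj(1 + x) for
# log1p. A minimal NumPy sketch of the same formulas:
import numpy as np

x = np.array([0.3 + 0.4j, 1.0 - 2.0j], dtype=np.complex64)
dout = np.ones_like(x)
dx_log = dout / np.conj(x)                   # LogGradFunctor<ComplexType<T>>
dx_log2 = dout / np.conj(x * np.log(2.0))    # Log2GradFunctor<ComplexType<T>>
dx_log10 = dout / np.conj(x * np.log(10.0))  # Log10GradFunctor<ComplexType<T>>
dx_log1p = dout / np.conj(1.0 + x)           # Log1pGradFunctor<ComplexType<T>>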
@@ -4642,6 +4774,16 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaLogFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // log(x) = log(x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_x) const {
+    return static_cast<ComplexType<T>>(log(arg_x));
+  }
+};
+
 template <typename T>
 struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
   // dx = dout / x
@@ -4652,6 +4794,18 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaLogGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // dx = dout / conj(x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    return dout / conj(x);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4665,6 +4819,17 @@ struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaLog1pFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // log1p(x) = log(1 + x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_x) const {
+    return static_cast<ComplexType<T>>(
+        log(static_cast<ComplexType<T>>(1) + arg_x));
+  }
+};
+
 template <typename T>
 struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
   T one = static_cast<T>(1.0f);
@@ -4677,6 +4842,20 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaLog1pGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
+
+  // dx = dout / conj(1 + x)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    return dout / conj(one + x);
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 __device__ __forceinline__
     std::conditional_t<std::is_integral<T>::value, float, T>
@@ -4709,6 +4888,17 @@ struct CudaLog2Functor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaLog2Functor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // log2(x) = log(x)/log(2)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_x) const {
+    return static_cast<ComplexType<T>>(log(arg_x) /
+                                       static_cast<ComplexType<T>>(log(2.0f)));
+  }
+};
+
 template <typename T>
 struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4722,6 +4912,18 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaLog2GradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // dx = dout / conj(x * log(2))
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    return dout / conj(x * static_cast<ComplexType<T>>(log(2.0f)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 __device__ __forceinline__
     std::conditional_t<std::is_integral<T>::value, float, T>
@@ -4754,6 +4956,17 @@ struct CudaLog10Functor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct CudaLog10Functor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // log10(x) = log(x)/log(10)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_x) const {
+    return static_cast<ComplexType<T>>(log(arg_x) /
+                                       static_cast<ComplexType<T>>(log(10.0f)));
+  }
+};
+
 template <typename T>
 struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4767,6 +4980,18 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaLog10GradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  // dx = dout / conj(x * log(10))
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    return dout / conj(x * static_cast<ComplexType<T>>(log(10.0f)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaSwishFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
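# Note (illustrative, not part of the patch): both the Eigen and the CUDA
# functors above reduce complex log2/log10/log1p to the natural log, e.g.
# log2(z) = log(z) / log(2). NumPy uses the same identities, so the forward
# results can be cross-checked directly:
import numpy as np

z = np.array([0.5 + 0.5j, -1.0 + 2.0j], dtype=np.complex128)
np.testing.assert_allclose(np.log2(z), np.log(z) / np.log(2.0), rtol=1e-12)
np.testing.assert_allclose(np.log10(z), np.log(z) / np.log(10.0), rtol=1e-12)
np.testing.assert_allclose(np.log1p(z), np.log(1.0 + z), rtol=1e-12)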
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 7af857345cdd67..594eefe5b8de17 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -510,10 +510,10 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad,
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
                                                 LogSigmoidGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log_grad, LogGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log2_grad, Log2GradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log10_grad, Log10GradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log1p_grad, Log1pGradKernel)
 PD_REGISTER_KERNEL(log_double_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -521,7 +521,9 @@ PD_REGISTER_KERNEL(log_double_grad,
                    float,
                    double,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
                                                 HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index e8dadf31fd945f..1bf3d92d806207 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -315,7 +315,9 @@ PD_REGISTER_KERNEL(log,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(log2,
                    GPU,
                    ALL_LAYOUT,
@@ -325,7 +327,9 @@ PD_REGISTER_KERNEL(log2,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(log10,
                    GPU,
                    ALL_LAYOUT,
@@ -335,7 +339,9 @@ PD_REGISTER_KERNEL(log10,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(log1p,
                    GPU,
                    ALL_LAYOUT,
@@ -345,7 +351,9 @@ PD_REGISTER_KERNEL(log1p,
                    int,
                    int64_t,
                    phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
+                   phi::dtype::bfloat16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
 PD_REGISTER_KERNEL(pow,
                    GPU,
                    ALL_LAYOUT,
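# Note (illustrative, not part of the patch): with a Paddle build that contains
# these registrations, the complex kernels can be exercised from dygraph on
# either device; the GPU branch only runs when Paddle was compiled with CUDA.
# grad_outputs defaults to ones, so the gradient of log should equal 1 / conj(x):
import numpy as np
import paddle

for device in ['cpu'] + (['gpu'] if paddle.is_compiled_with_cuda() else []):
    paddle.set_device(device)
    x_np = np.array([1.0 + 2.0j, 3.0 - 4.0j], dtype=np.complex64)
    x = paddle.to_tensor(x_np, stop_gradient=False)
    y = paddle.log(x)              # forward kernel: log (complex64)
    (dx,) = paddle.grad([y], [x])  # backward kernel: log_grad (complex64)
    np.testing.assert_allclose(dx.numpy(), 1.0 / np.conj(x_np), rtol=1e-5)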
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index a931912ae95727..eace002859e865 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -157,7 +157,7 @@ def log(x, name=None):
         Out = \ln(x)
 
     Args:
-        x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
+        x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64, complex64, complex128.
         name (str|None): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
 
 
@@ -183,7 +183,16 @@ def log(x, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['int32', 'int64', 'uint16', 'float16', 'float32', 'float64'],
+            [
+                'int32',
+                'int64',
+                'uint16',
+                'float16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
             "log",
         )
         inputs = {'X': [x]}
@@ -3303,7 +3312,7 @@ def log1p(x, name=None):
         Out = \ln(x+1)
 
     Args:
-        x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
+        x (Tensor): Input Tensor. Must be one of the following types: int32, int64, float16, bfloat16, float32, float64, complex64, complex128.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -3328,7 +3337,16 @@ def log1p(x, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
+            [
+                'int32',
+                'int64',
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
             "log1p",
         )
         inputs = {'X': [x]}
@@ -3359,7 +3377,7 @@ def log2(x, name=None):
         Out = \log_2x
 
     Args:
-        x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
+        x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64, complex64, complex128.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
 
@@ -3402,7 +3420,16 @@ def log2(x, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
+            [
+                'int32',
+                'int64',
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
             "log2",
         )
         inputs = {'X': [x]}
@@ -3433,7 +3460,7 @@ def log10(x, name=None):
         Out = \log_10_x
 
     Args:
-        x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64.
+        x (Tensor): Input tensor must be one of the following types: int32, int64, float16, bfloat16, float32, float64, complex64, complex128.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
 
@@ -3476,7 +3503,16 @@ def log10(x, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['int32', 'int64', 'float16', 'uint16', 'float32', 'float64'],
+            [
+                'int32',
+                'int64',
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
             "log10",
         )
         inputs = {'X': [x]}
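# Note (illustrative, not part of the patch): the extended dtype lists above let
# complex64/complex128 inputs pass check_variable_and_dtype in static graph
# mode; in dygraph the C++ kernels are dispatched directly. Quick forward
# comparison against NumPy (assumes a build that includes this change):
import numpy as np
import paddle

np_x = np.array([1.0 + 2.0j, 3.0 - 4.0j, 0.5 + 0.5j], dtype=np.complex64)
x = paddle.to_tensor(np_x)
for api, ref in [
    (paddle.log, np.log),
    (paddle.log2, np.log2),
    (paddle.log10, np.log10),
    (paddle.log1p, np.log1p),
]:
    np.testing.assert_allclose(api(x).numpy(), ref(np_x), rtol=1e-3)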
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index ffd8e85d2cd247..de8babf716cf1e 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -29,6 +29,8 @@
 from paddle.base.layer_helper import LayerHelper
 from paddle.pir_utils import test_with_pir_api
 
+devices = ['cpu', 'gpu']
+
 
 @contextmanager
 def dynamic_guard():
@@ -3739,6 +3741,11 @@ def setUp(self):
 
         np.random.seed(1024)
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
         out = np.log(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3766,6 +3773,56 @@ def test_check_grad(self):
         )
 
 
+class TestLog_Complex64(TestLog):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'], 'Out', check_pir=True, check_pir_onednn=self.check_pir_onednn
+        )
+
+    def test_check_output(self):
+        self.check_output(
+            check_pir=True, check_pir_onednn=self.check_pir_onednn
+        )
+
+    def test_api_complex(self):
+        paddle.disable_static()
+        for device in devices:
+            if device == 'cpu' or (
+                device == 'gpu' and paddle.is_compiled_with_cuda()
+            ):
+                np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=self.dtype)
+                x = paddle.to_tensor(np_x, dtype=self.dtype, place=device)
+                y = paddle.log(x)
+                x_expect = np.log(np_x)
+                np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
+        paddle.enable_static()
+
+    def test_grad_grad(self):
+        paddle.disable_static()
+        x_numpy = (
+            np.random.uniform(0.1, 1, self.shape)
+            + 1j * np.random.uniform(0.1, 1, self.shape)
+        ).astype(self.dtype)
+
+        expected_ddx = np.conj(-1 / np.power(x_numpy, 2))
+
+        x = paddle.to_tensor(x_numpy, stop_gradient=False)
+        y = paddle.log(x)
+        dx = paddle.grad(
+            outputs=[y], inputs=[x], create_graph=True, retain_graph=True
+        )[0]
+        ddx = paddle.grad(outputs=[dx], inputs=[x], retain_graph=True)[0]
+        np.testing.assert_allclose(ddx.numpy(), expected_ddx, rtol=1e-3)
+
+
+class TestLog_Complex128(TestLog_Complex64):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class Test_Log_Op_Fp16(unittest.TestCase):
     def test_api_fp16(self):
         with static_guard():
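# Note (illustrative, not part of the patch): test_grad_grad above builds
# expected_ddx = conj(-1 / x**2). That matches the formula implemented by
# LogGradGradFunctor<ComplexType<T>>, dx = dout * (-1) * ddx / conj(x * x),
# which with unit seeds (dout = ddx = ones) reduces to -1 / conj(x**2), i.e.
# conj(-1 / x**2):
import numpy as np

x = (
    np.random.uniform(0.1, 1, [4]) + 1j * np.random.uniform(0.1, 1, [4])
).astype(np.complex128)
dout = np.ones_like(x)
ddx_seed = np.ones_like(x)
functor_dx = dout * -1.0 * ddx_seed / np.conj(x * x)
np.testing.assert_allclose(functor_dx, np.conj(-1.0 / np.power(x, 2)), rtol=1e-12)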
@@ -3819,6 +3876,11 @@ def setUp(self):
         self.init_shape()
 
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
         out = np.log2(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3864,6 +3926,34 @@ def test_api(self):
             np.testing.assert_allclose(np_z, z_expected, rtol=1e-05)
 
 
+class TestLog2_Complex64(TestLog2):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def test_check_output(self):
+        self.check_output(
+            check_pir=True, check_pir_onednn=self.check_pir_onednn
+        )
+
+    def test_api_complex(self):
+        paddle.disable_static()
+        for device in devices:
+            if device == 'cpu' or (
+                device == 'gpu' and paddle.is_compiled_with_cuda()
+            ):
+                np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=self.dtype)
+                x = paddle.to_tensor(np_x, dtype=self.dtype, place=device)
+                y = paddle.log2(x)
+                x_expect = np.log2(np_x)
+                np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
+        paddle.enable_static()
+
+
+class TestLog2_Complex128(TestLog2_Complex64):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestLog2_ZeroDim(TestLog2):
     def init_shape(self):
         self.shape = []
@@ -3903,6 +3993,11 @@ def setUp(self):
         self.init_shape()
 
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
         out = np.log10(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -3922,6 +4017,29 @@ def test_check_grad(self):
         )
 
 
+class TestLog10_Complex64(TestLog10):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def test_api_complex(self):
+        paddle.disable_static()
+        for device in devices:
+            if device == 'cpu' or (
+                device == 'gpu' and paddle.is_compiled_with_cuda()
+            ):
+                np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=self.dtype)
+                x = paddle.to_tensor(np_x, dtype=self.dtype, place=device)
+                y = paddle.log10(x)
+                x_expect = np.log10(np_x)
+                np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
+        paddle.enable_static()
+
+
+class TestLog10_Complex128(TestLog10_Complex64):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestLog10_ZeroDim(TestLog10):
     def init_shape(self):
         self.shape = []
@@ -3995,6 +4113,11 @@ def setUp(self):
 
         np.random.seed(1024)
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(0.1, 1, self.shape)
+                + 1j * np.random.uniform(0.1, 1, self.shape)
+            ).astype(self.dtype)
         out = np.log1p(x)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -4014,6 +4137,29 @@ def test_check_grad(self):
         )
 
 
+class TestLog1p_Complex64(TestLog1p):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+    def test_api_complex(self):
+        paddle.disable_static()
+        for device in devices:
+            if device == 'cpu' or (
+                device == 'gpu' and paddle.is_compiled_with_cuda()
+            ):
+                np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=self.dtype)
+                x = paddle.to_tensor(np_x, dtype=self.dtype, place=device)
+                y = paddle.log1p(x)
+                x_expect = np.log1p(np_x)
+                np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)
+        paddle.enable_static()
+
+
+class TestLog1p_Complex128(TestLog1p_Complex64):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class Test_Log1p_Op_Fp16(unittest.TestCase):
     @test_with_pir_api
     def test_api_fp16(self):