Skip to content

Commit ed19f42

Browse files
authored
【complex op No.7】add complex support for Log/log10/log2/log1p (#62448)
* log complex
* remove int backward
* add device info
* remove duplicate implementation
* fix device info
* add gradgrad test for log
1 parent 0a2e7b6 commit ed19f42

7 files changed

Lines changed: 452 additions & 26 deletions

File tree

paddle/phi/kernels/cpu/activation_grad_kernel.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -438,11 +438,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad,
438438
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
439439
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
440440
LogSigmoidGradKernel)
441-
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
442-
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
443-
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
444-
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
445-
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
441+
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log_grad, LogGradKernel)
442+
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log2_grad, Log2GradKernel)
443+
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log10_grad, Log10GradKernel)
444+
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log1p_grad, Log1pGradKernel)
445+
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(log_double_grad,
446+
LogDoubleGradKernel)
446447
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad,
447448
HardSwishGradKernel)
448449
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)

paddle/phi/kernels/cpu/activation_kernel.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ PD_REGISTER_KERNEL(log,
254254
int,
255255
int64_t,
256256
phi::dtype::float16,
257-
phi::dtype::bfloat16) {}
257+
phi::dtype::bfloat16,
258+
phi::dtype::complex<float>,
259+
phi::dtype::complex<double>) {}
258260
PD_REGISTER_KERNEL(log2,
259261
CPU,
260262
ALL_LAYOUT,
@@ -264,7 +266,9 @@ PD_REGISTER_KERNEL(log2,
264266
int,
265267
int64_t,
266268
phi::dtype::float16,
267-
phi::dtype::bfloat16) {}
269+
phi::dtype::bfloat16,
270+
phi::dtype::complex<float>,
271+
phi::dtype::complex<double>) {}
268272
PD_REGISTER_KERNEL(log10,
269273
CPU,
270274
ALL_LAYOUT,
@@ -274,7 +278,9 @@ PD_REGISTER_KERNEL(log10,
274278
int,
275279
int64_t,
276280
phi::dtype::float16,
277-
phi::dtype::bfloat16) {}
281+
phi::dtype::bfloat16,
282+
phi::dtype::complex<float>,
283+
phi::dtype::complex<double>) {}
278284
PD_REGISTER_KERNEL(log1p,
279285
CPU,
280286
ALL_LAYOUT,
@@ -284,7 +290,9 @@ PD_REGISTER_KERNEL(log1p,
284290
int,
285291
int64_t,
286292
phi::dtype::float16,
287-
phi::dtype::bfloat16) {}
293+
phi::dtype::bfloat16,
294+
phi::dtype::complex<float>,
295+
phi::dtype::complex<double>) {}
288296

289297
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
290298
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,6 +2445,13 @@ struct Log {
24452445
HOSTDEVICE T operator()(const T& val) const { return std::log(val); }
24462446
};
24472447

2448+
template <typename T>
2449+
struct Log<ComplexType<T>> {
2450+
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
2451+
return ComplexType<T>(std::log(std::complex<T>(val)));
2452+
}
2453+
};
2454+
24482455
template <>
24492456
struct Log<dtype::float16> {
24502457
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2484,11 +2491,35 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
24842491
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
24852492
};
24862493

2494+
template <typename T>
2495+
struct LogGradFunctor<ComplexType<T>>
2496+
: public BaseActivationFunctor<ComplexType<T>> {
2497+
template <typename Device,
2498+
typename X,
2499+
typename Out,
2500+
typename dOut,
2501+
typename dX>
2502+
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2503+
dx.device(d) =
2504+
dout * (static_cast<ComplexType<T>>(1) / x).unaryExpr(Conj<T>());
2505+
}
2506+
2507+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
2508+
};
2509+
24872510
template <typename T>
24882511
struct Log2 {
24892512
HOSTDEVICE T operator()(const T& val) const { return std::log2(val); }
24902513
};
24912514

2515+
template <typename T>
2516+
struct Log2<ComplexType<T>> {
2517+
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
2518+
return ComplexType<T>(std::log(std::complex<T>(val)) /
2519+
std::log(std::complex<T>(2)));
2520+
}
2521+
};
2522+
24922523
template <>
24932524
struct Log2<dtype::float16> {
24942525
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2529,11 +2560,35 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
25292560
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
25302561
};
25312562

2563+
template <typename T>
2564+
struct Log2GradFunctor<ComplexType<T>>
2565+
: public BaseActivationFunctor<ComplexType<T>> {
2566+
template <typename Device,
2567+
typename X,
2568+
typename Out,
2569+
typename dOut,
2570+
typename dX>
2571+
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2572+
dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
2573+
(x * static_cast<ComplexType<T>>(log(2))))
2574+
.unaryExpr(Conj<T>());
2575+
}
2576+
2577+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
2578+
};
2579+
25322580
template <typename T>
25332581
struct Log10 {
25342582
HOSTDEVICE T operator()(const T& val) const { return std::log10(val); }
25352583
};
25362584

2585+
template <typename T>
2586+
struct Log10<ComplexType<T>> {
2587+
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
2588+
return ComplexType<T>(std::log10(std::complex<T>(val)));
2589+
}
2590+
};
2591+
25372592
template <>
25382593
struct Log10<dtype::float16> {
25392594
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2574,11 +2629,35 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> {
25742629
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
25752630
};
25762631

2632+
template <typename T>
2633+
struct Log10GradFunctor<ComplexType<T>>
2634+
: public BaseActivationFunctor<ComplexType<T>> {
2635+
template <typename Device,
2636+
typename X,
2637+
typename Out,
2638+
typename dOut,
2639+
typename dX>
2640+
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2641+
dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
2642+
(x * static_cast<ComplexType<T>>(log(10))))
2643+
.unaryExpr(Conj<T>());
2644+
}
2645+
2646+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
2647+
};
2648+
25772649
template <typename T>
25782650
struct Log1p {
25792651
HOSTDEVICE T operator()(const T& val) const { return std::log1p(val); }
25802652
};
25812653

2654+
template <typename T>
2655+
struct Log1p<ComplexType<T>> {
2656+
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
2657+
return ComplexType<T>(std::log(std::complex<T>(1) + std::complex<T>(val)));
2658+
}
2659+
};
2660+
25822661
template <>
25832662
struct Log1p<dtype::float16> {
25842663
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
@@ -2618,6 +2697,23 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
26182697
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
26192698
};
26202699

2700+
template <typename T>
2701+
struct Log1pGradFunctor<ComplexType<T>>
2702+
: public BaseActivationFunctor<ComplexType<T>> {
2703+
template <typename Device,
2704+
typename X,
2705+
typename Out,
2706+
typename dOut,
2707+
typename dX>
2708+
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2709+
dx.device(d) = dout * (static_cast<ComplexType<T>>(1) /
2710+
(x + static_cast<ComplexType<T>>(1)))
2711+
.unaryExpr(Conj<T>());
2712+
}
2713+
2714+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
2715+
};
2716+
26212717
template <typename T>
26222718
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
26232719
template <typename Device>
@@ -2651,6 +2747,42 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
26512747
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
26522748
};
26532749

2750+
template <typename T>
2751+
struct LogGradGradFunctor<ComplexType<T>>
2752+
: public BaseActivationFunctor<ComplexType<T>> {
2753+
template <typename Device>
2754+
void operator()(const Device& dev,
2755+
const DenseTensor* X,
2756+
const DenseTensor* ddX,
2757+
DenseTensor* ddOut,
2758+
const DenseTensor* dOut,
2759+
DenseTensor* dX) const {
2760+
auto* d = dev.eigen_device();
2761+
auto ddx = EigenVector<ComplexType<T>>::Flatten(
2762+
GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
2763+
auto x = EigenVector<ComplexType<T>>::Flatten(
2764+
GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
2765+
// ddout = ddx / x; dx = -(dout / x) * (ddx / x)
2766+
// calculate dx first, so ddout can inplace ddx
2767+
if (dX) {
2768+
auto dout = EigenVector<ComplexType<T>>::Flatten(
2769+
GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
2770+
auto dx = EigenVector<ComplexType<T>>::Flatten(
2771+
GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
2772+
dx.device(*d) = dout * static_cast<ComplexType<T>>(-1) * ddx /
2773+
(x * x).unaryExpr(Conj<T>());
2774+
}
2775+
if (ddOut) {
2776+
auto ddout = EigenVector<ComplexType<T>>::Flatten(
2777+
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
2778+
ddout.device(*d) =
2779+
ddx * static_cast<ComplexType<T>>(1) / x.unaryExpr(Conj<T>());
2780+
}
2781+
}
2782+
2783+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
2784+
};
2785+
26542786
// HardSwish = min(max(0, x+3), 6) * x / 6
26552787
template <typename T>
26562788
struct HardSwishFunctor : public BaseActivationFunctor<T> {
@@ -4642,6 +4774,16 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> {
46424774
}
46434775
};
46444776

4777+
template <typename T>
4778+
struct CudaLogFunctor<ComplexType<T>>
4779+
: public BaseActivationFunctor<ComplexType<T>> {
4780+
// log(x) = log(x)
4781+
__device__ __forceinline__ ComplexType<T> operator()(
4782+
const ComplexType<T> arg_x) const {
4783+
return static_cast<ComplexType<T>>(log(arg_x));
4784+
}
4785+
};
4786+
46454787
template <typename T>
46464788
struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
46474789
// dx = dout / x
@@ -4652,6 +4794,18 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
46524794
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
46534795
};
46544796

4797+
template <typename T>
4798+
struct CudaLogGradFunctor<ComplexType<T>>
4799+
: public BaseActivationFunctor<ComplexType<T>> {
4800+
// dx = dout / conj(x)
4801+
__device__ __forceinline__ ComplexType<T> operator()(
4802+
const ComplexType<T> dout, const ComplexType<T> x) const {
4803+
return dout / conj(x);
4804+
}
4805+
4806+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
4807+
};
4808+
46554809
template <typename T>
46564810
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
46574811
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4665,6 +4819,17 @@ struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
46654819
}
46664820
};
46674821

4822+
template <typename T>
4823+
struct CudaLog1pFunctor<ComplexType<T>>
4824+
: public BaseActivationFunctor<ComplexType<T>> {
4825+
// log1p(x) = log(1 + x)
4826+
__device__ __forceinline__ ComplexType<T> operator()(
4827+
const ComplexType<T> arg_x) const {
4828+
return static_cast<ComplexType<T>>(
4829+
log(static_cast<ComplexType<T>>(1) + arg_x));
4830+
}
4831+
};
4832+
46684833
template <typename T>
46694834
struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
46704835
T one = static_cast<T>(1.0f);
@@ -4677,6 +4842,20 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
46774842
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
46784843
};
46794844

4845+
template <typename T>
4846+
struct CudaLog1pGradFunctor<ComplexType<T>>
4847+
: public BaseActivationFunctor<ComplexType<T>> {
4848+
ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
4849+
4850+
// dx = dout / conj(1 + x)
4851+
__device__ __forceinline__ ComplexType<T> operator()(
4852+
const ComplexType<T> dout, const ComplexType<T> x) const {
4853+
return dout / conj(one + x);
4854+
}
4855+
4856+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
4857+
};
4858+
46804859
template <typename T>
46814860
__device__ __forceinline__
46824861
std::conditional_t<std::is_integral<T>::value, float, T>
@@ -4709,6 +4888,17 @@ struct CudaLog2Functor : public BaseActivationFunctor<T> {
47094888
}
47104889
};
47114890

4891+
template <typename T>
4892+
struct CudaLog2Functor<ComplexType<T>>
4893+
: public BaseActivationFunctor<ComplexType<T>> {
4894+
// log2(x) = log(x)/log(2)
4895+
__device__ __forceinline__ ComplexType<T> operator()(
4896+
const ComplexType<T> arg_x) const {
4897+
return static_cast<ComplexType<T>>(log(arg_x) /
4898+
static_cast<ComplexType<T>>(log(2.0f)));
4899+
}
4900+
};
4901+
47124902
template <typename T>
47134903
struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
47144904
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4722,6 +4912,18 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
47224912
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
47234913
};
47244914

4915+
template <typename T>
4916+
struct CudaLog2GradFunctor<ComplexType<T>>
4917+
: public BaseActivationFunctor<ComplexType<T>> {
4918+
// dx = dout / conj(x * log(2))
4919+
__device__ __forceinline__ ComplexType<T> operator()(
4920+
const ComplexType<T> dout, const ComplexType<T> x) const {
4921+
return dout / conj(x * static_cast<ComplexType<T>>(log(2.0f)));
4922+
}
4923+
4924+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
4925+
};
4926+
47254927
template <typename T>
47264928
__device__ __forceinline__
47274929
std::conditional_t<std::is_integral<T>::value, float, T>
@@ -4754,6 +4956,17 @@ struct CudaLog10Functor : public BaseActivationFunctor<T> {
47544956
}
47554957
};
47564958

4959+
template <typename T>
4960+
struct CudaLog10Functor<ComplexType<T>>
4961+
: public BaseActivationFunctor<ComplexType<T>> {
4962+
// log10(x) = log(x)/log(10)
4963+
__device__ __forceinline__ ComplexType<T> operator()(
4964+
const ComplexType<T> arg_x) const {
4965+
return static_cast<ComplexType<T>>(log(arg_x) /
4966+
static_cast<ComplexType<T>>(log(10.0f)));
4967+
}
4968+
};
4969+
47574970
template <typename T>
47584971
struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
47594972
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4767,6 +4980,18 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
47674980
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
47684981
};
47694982

4983+
template <typename T>
4984+
struct CudaLog10GradFunctor<ComplexType<T>>
4985+
: public BaseActivationFunctor<ComplexType<T>> {
4986+
// dx = dout / conj(x * log(10))
4987+
__device__ __forceinline__ ComplexType<T> operator()(
4988+
const ComplexType<T> dout, const ComplexType<T> x) const {
4989+
return dout / conj(x * static_cast<ComplexType<T>>(log(10.0f)));
4990+
}
4991+
4992+
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
4993+
};
4994+
47704995
template <typename T>
47714996
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
47724997
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

0 commit comments

Comments (0)