@@ -2445,6 +2445,13 @@ struct Log {
24452445 HOSTDEVICE T operator ()(const T& val) const { return std::log (val); }
24462446};
24472447
2448+ template <typename T>
2449+ struct Log <ComplexType<T>> {
2450+ HOSTDEVICE ComplexType<T> operator ()(const ComplexType<T>& val) const {
2451+ return ComplexType<T>(std::log (std::complex <T>(val)));
2452+ }
2453+ };
2454+
24482455template <>
24492456struct Log <dtype::float16> {
24502457 HOSTDEVICE dtype::float16 operator ()(const dtype::float16& val) const {
@@ -2484,11 +2491,35 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
24842491 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
24852492};
24862493
2494+ template <typename T>
2495+ struct LogGradFunctor <ComplexType<T>>
2496+ : public BaseActivationFunctor<ComplexType<T>> {
2497+ template <typename Device,
2498+ typename X,
2499+ typename Out,
2500+ typename dOut,
2501+ typename dX>
2502+ void operator ()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2503+ dx.device (d) =
2504+ dout * (static_cast <ComplexType<T>>(1 ) / x).unaryExpr (Conj<T>());
2505+ }
2506+
2507+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
2508+ };
2509+
24872510template <typename T>
24882511struct Log2 {
24892512 HOSTDEVICE T operator ()(const T& val) const { return std::log2 (val); }
24902513};
24912514
2515+ template <typename T>
2516+ struct Log2 <ComplexType<T>> {
2517+ HOSTDEVICE ComplexType<T> operator ()(const ComplexType<T>& val) const {
2518+ return ComplexType<T>(std::log (std::complex <T>(val)) /
2519+ std::log (std::complex <T>(2 )));
2520+ }
2521+ };
2522+
24922523template <>
24932524struct Log2 <dtype::float16> {
24942525 HOSTDEVICE dtype::float16 operator ()(const dtype::float16& val) const {
@@ -2529,11 +2560,35 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
25292560 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
25302561};
25312562
2563+ template <typename T>
2564+ struct Log2GradFunctor <ComplexType<T>>
2565+ : public BaseActivationFunctor<ComplexType<T>> {
2566+ template <typename Device,
2567+ typename X,
2568+ typename Out,
2569+ typename dOut,
2570+ typename dX>
2571+ void operator ()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2572+ dx.device (d) = dout * (static_cast <ComplexType<T>>(1 ) /
2573+ (x * static_cast <ComplexType<T>>(log (2 ))))
2574+ .unaryExpr (Conj<T>());
2575+ }
2576+
2577+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
2578+ };
2579+
25322580template <typename T>
25332581struct Log10 {
25342582 HOSTDEVICE T operator ()(const T& val) const { return std::log10 (val); }
25352583};
25362584
2585+ template <typename T>
2586+ struct Log10 <ComplexType<T>> {
2587+ HOSTDEVICE ComplexType<T> operator ()(const ComplexType<T>& val) const {
2588+ return ComplexType<T>(std::log10 (std::complex <T>(val)));
2589+ }
2590+ };
2591+
25372592template <>
25382593struct Log10 <dtype::float16> {
25392594 HOSTDEVICE dtype::float16 operator ()(const dtype::float16& val) const {
@@ -2574,11 +2629,35 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> {
25742629 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
25752630};
25762631
2632+ template <typename T>
2633+ struct Log10GradFunctor <ComplexType<T>>
2634+ : public BaseActivationFunctor<ComplexType<T>> {
2635+ template <typename Device,
2636+ typename X,
2637+ typename Out,
2638+ typename dOut,
2639+ typename dX>
2640+ void operator ()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2641+ dx.device (d) = dout * (static_cast <ComplexType<T>>(1 ) /
2642+ (x * static_cast <ComplexType<T>>(log (10 ))))
2643+ .unaryExpr (Conj<T>());
2644+ }
2645+
2646+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
2647+ };
2648+
25772649template <typename T>
25782650struct Log1p {
25792651 HOSTDEVICE T operator ()(const T& val) const { return std::log1p (val); }
25802652};
25812653
2654+ template <typename T>
2655+ struct Log1p <ComplexType<T>> {
2656+ HOSTDEVICE ComplexType<T> operator ()(const ComplexType<T>& val) const {
2657+ return ComplexType<T>(std::log (std::complex <T>(1 ) + std::complex <T>(val)));
2658+ }
2659+ };
2660+
25822661template <>
25832662struct Log1p <dtype::float16> {
25842663 HOSTDEVICE dtype::float16 operator ()(const dtype::float16& val) const {
@@ -2618,6 +2697,23 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
26182697 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
26192698};
26202699
2700+ template <typename T>
2701+ struct Log1pGradFunctor <ComplexType<T>>
2702+ : public BaseActivationFunctor<ComplexType<T>> {
2703+ template <typename Device,
2704+ typename X,
2705+ typename Out,
2706+ typename dOut,
2707+ typename dX>
2708+ void operator ()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
2709+ dx.device (d) = dout * (static_cast <ComplexType<T>>(1 ) /
2710+ (x + static_cast <ComplexType<T>>(1 )))
2711+ .unaryExpr (Conj<T>());
2712+ }
2713+
2714+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
2715+ };
2716+
26212717template <typename T>
26222718struct LogGradGradFunctor : public BaseActivationFunctor <T> {
26232719 template <typename Device>
@@ -2651,6 +2747,42 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
26512747 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
26522748};
26532749
2750+ template <typename T>
2751+ struct LogGradGradFunctor <ComplexType<T>>
2752+ : public BaseActivationFunctor<ComplexType<T>> {
2753+ template <typename Device>
2754+ void operator ()(const Device& dev,
2755+ const DenseTensor* X,
2756+ const DenseTensor* ddX,
2757+ DenseTensor* ddOut,
2758+ const DenseTensor* dOut,
2759+ DenseTensor* dX) const {
2760+ auto * d = dev.eigen_device ();
2761+ auto ddx = EigenVector<ComplexType<T>>::Flatten (
2762+ GET_DATA_SAFELY (ddX, " Input" , " DDX" , " LogGradGrad" ));
2763+ auto x = EigenVector<ComplexType<T>>::Flatten (
2764+ GET_DATA_SAFELY (X, " Input" , " X" , " LogGradGrad" ));
2765+ // ddout = ddx / x; dx = -(dout / x) * (ddx / x)
2766+ // calculate dx first, so ddout can inplace ddx
2767+ if (dX) {
2768+ auto dout = EigenVector<ComplexType<T>>::Flatten (
2769+ GET_DATA_SAFELY (dOut, " Output" , " DOut" , " LogGradGrad" ));
2770+ auto dx = EigenVector<ComplexType<T>>::Flatten (
2771+ GET_DATA_SAFELY (dX, " Output" , " DX" , " LogGradGrad" ));
2772+ dx.device (*d) = dout * static_cast <ComplexType<T>>(-1 ) * ddx /
2773+ (x * x).unaryExpr (Conj<T>());
2774+ }
2775+ if (ddOut) {
2776+ auto ddout = EigenVector<ComplexType<T>>::Flatten (
2777+ GET_DATA_SAFELY (ddOut, " Output" , " DDOut" , " LogGradGrad" ));
2778+ ddout.device (*d) =
2779+ ddx * static_cast <ComplexType<T>>(1 ) / x.unaryExpr (Conj<T>());
2780+ }
2781+ }
2782+
2783+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
2784+ };
2785+
26542786// HardSwish = min(max(0, x+3), 6) * x / 6
26552787template <typename T>
26562788struct HardSwishFunctor : public BaseActivationFunctor <T> {
@@ -4642,6 +4774,16 @@ struct CudaLogFunctor : public BaseActivationFunctor<T> {
46424774 }
46434775};
46444776
4777+ template <typename T>
4778+ struct CudaLogFunctor <ComplexType<T>>
4779+ : public BaseActivationFunctor<ComplexType<T>> {
4780+ // log(x) = log(x)
4781+ __device__ __forceinline__ ComplexType<T> operator ()(
4782+ const ComplexType<T> arg_x) const {
4783+ return static_cast <ComplexType<T>>(log (arg_x));
4784+ }
4785+ };
4786+
46454787template <typename T>
46464788struct CudaLogGradFunctor : public BaseActivationFunctor <T> {
46474789 // dx = dout / x
@@ -4652,6 +4794,18 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
46524794 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
46534795};
46544796
4797+ template <typename T>
4798+ struct CudaLogGradFunctor <ComplexType<T>>
4799+ : public BaseActivationFunctor<ComplexType<T>> {
4800+ // dx = dout / conj(x)
4801+ __device__ __forceinline__ ComplexType<T> operator ()(
4802+ const ComplexType<T> dout, const ComplexType<T> x) const {
4803+ return dout / conj (x);
4804+ }
4805+
4806+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
4807+ };
4808+
46554809template <typename T>
46564810struct CudaLog1pFunctor : public BaseActivationFunctor <T> {
46574811 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4665,6 +4819,17 @@ struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
46654819 }
46664820};
46674821
4822+ template <typename T>
4823+ struct CudaLog1pFunctor <ComplexType<T>>
4824+ : public BaseActivationFunctor<ComplexType<T>> {
4825+ // log1p(x) = log(1 + x)
4826+ __device__ __forceinline__ ComplexType<T> operator ()(
4827+ const ComplexType<T> arg_x) const {
4828+ return static_cast <ComplexType<T>>(
4829+ log (static_cast <ComplexType<T>>(1 ) + arg_x));
4830+ }
4831+ };
4832+
46684833template <typename T>
46694834struct CudaLog1pGradFunctor : public BaseActivationFunctor <T> {
46704835 T one = static_cast <T>(1 .0f );
@@ -4677,6 +4842,20 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
46774842 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
46784843};
46794844
4845+ template <typename T>
4846+ struct CudaLog1pGradFunctor <ComplexType<T>>
4847+ : public BaseActivationFunctor<ComplexType<T>> {
4848+ ComplexType<T> one = static_cast <ComplexType<T>>(1 .0f );
4849+
4850+ // dx = dout / conj(1 + x)
4851+ __device__ __forceinline__ ComplexType<T> operator ()(
4852+ const ComplexType<T> dout, const ComplexType<T> x) const {
4853+ return dout / conj (one + x);
4854+ }
4855+
4856+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
4857+ };
4858+
46804859template <typename T>
46814860__device__ __forceinline__
46824861 std::conditional_t <std::is_integral<T>::value, float , T>
@@ -4709,6 +4888,17 @@ struct CudaLog2Functor : public BaseActivationFunctor<T> {
47094888 }
47104889};
47114890
4891+ template <typename T>
4892+ struct CudaLog2Functor <ComplexType<T>>
4893+ : public BaseActivationFunctor<ComplexType<T>> {
4894+ // log2(x) = log(x)/log(2)
4895+ __device__ __forceinline__ ComplexType<T> operator ()(
4896+ const ComplexType<T> arg_x) const {
4897+ return static_cast <ComplexType<T>>(log (arg_x) /
4898+ static_cast <ComplexType<T>>(log (2 .0f )));
4899+ }
4900+ };
4901+
47124902template <typename T>
47134903struct CudaLog2GradFunctor : public BaseActivationFunctor <T> {
47144904 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4722,6 +4912,18 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
47224912 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
47234913};
47244914
4915+ template <typename T>
4916+ struct CudaLog2GradFunctor <ComplexType<T>>
4917+ : public BaseActivationFunctor<ComplexType<T>> {
4918+ // dx = dout / conj(x * log(2))
4919+ __device__ __forceinline__ ComplexType<T> operator ()(
4920+ const ComplexType<T> dout, const ComplexType<T> x) const {
4921+ return dout / conj (x * static_cast <ComplexType<T>>(log (2 .0f )));
4922+ }
4923+
4924+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
4925+ };
4926+
47254927template <typename T>
47264928__device__ __forceinline__
47274929 std::conditional_t <std::is_integral<T>::value, float , T>
@@ -4754,6 +4956,17 @@ struct CudaLog10Functor : public BaseActivationFunctor<T> {
47544956 }
47554957};
47564958
4959+ template <typename T>
4960+ struct CudaLog10Functor <ComplexType<T>>
4961+ : public BaseActivationFunctor<ComplexType<T>> {
4962+ // log10(x) = log(x)/log(10)
4963+ __device__ __forceinline__ ComplexType<T> operator ()(
4964+ const ComplexType<T> arg_x) const {
4965+ return static_cast <ComplexType<T>>(log (arg_x) /
4966+ static_cast <ComplexType<T>>(log (10 .0f )));
4967+ }
4968+ };
4969+
47574970template <typename T>
47584971struct CudaLog10GradFunctor : public BaseActivationFunctor <T> {
47594972 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4767,6 +4980,18 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
47674980 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
47684981};
47694982
4983+ template <typename T>
4984+ struct CudaLog10GradFunctor <ComplexType<T>>
4985+ : public BaseActivationFunctor<ComplexType<T>> {
4986+ // dx = dout / conj(x * log(10))
4987+ __device__ __forceinline__ ComplexType<T> operator ()(
4988+ const ComplexType<T> dout, const ComplexType<T> x) const {
4989+ return dout / conj (x * static_cast <ComplexType<T>>(log (10 .0f )));
4990+ }
4991+
4992+ static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
4993+ };
4994+
47704995template <typename T>
47714996struct CudaSwishFunctor : public BaseActivationFunctor <T> {
47724997 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
0 commit comments