|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #pragma once |
16 | | - |
17 | | -#include "paddle/fluid/operators/amp/fp16_type_traits.h" |
18 | | -#include "paddle/fluid/platform/enforce.h" |
19 | | -#include "paddle/fluid/platform/float16.h" |
20 | | -#include "paddle/pten/kernels/funcs/eigen/extensions.h" |
| 16 | +#include "paddle/pten/kernels/primitive/functor_primitives.h" |
21 | 17 |
|
22 | 18 | namespace paddle { |
23 | 19 | namespace operators { |
24 | | -namespace kernel_primitives { |
25 | | -namespace details { |
26 | | - |
27 | | -static __device__ __forceinline__ platform::float16 Exp(platform::float16 x) { |
28 | | - return ::Eigen::numext::exp(x); |
29 | | -} |
30 | | - |
31 | | -static __device__ __forceinline__ float Exp(float x) { return expf(x); } |
32 | | - |
33 | | -static __device__ __forceinline__ double Exp(double x) { return exp(x); } |
34 | | - |
35 | | -static __device__ __forceinline__ platform::float16 Log(platform::float16 x) { |
36 | | - return ::Eigen::numext::log(x); |
37 | | -} |
38 | | - |
39 | | -static __device__ __forceinline__ float Log(float x) { return logf(x); } |
40 | | - |
41 | | -static __device__ __forceinline__ double Log(double x) { return log(x); } |
42 | | - |
43 | | -} // namespace details |
44 | | - |
45 | | -/******************************** Unary Functor *******************************/ |
46 | | - |
47 | | -/** |
48 | | - * @brief Default unary exp functor |
49 | | - */ |
50 | | -template <typename Tx, typename Ty = Tx> |
51 | | -struct ExpFunctor { |
52 | | - HOSTDEVICE inline ExpFunctor() {} |
53 | | - |
54 | | - HOSTDEVICE explicit inline ExpFunctor(int n) {} |
55 | | - |
56 | | - HOSTDEVICE inline Ty operator()(const Tx x) const { |
57 | | - return static_cast<Ty>(details::Exp(x)); |
58 | | - } |
59 | | -}; |
60 | | - |
61 | | -/** |
62 | | - * @brief Default unary identity functor |
63 | | - */ |
64 | | -template <typename Tx, typename Ty = Tx> |
65 | | -struct IdentityFunctor { |
66 | | - HOSTDEVICE inline IdentityFunctor() {} |
67 | | - |
68 | | - HOSTDEVICE explicit inline IdentityFunctor(int n) {} |
69 | | - |
70 | | - HOSTDEVICE inline Ty operator()(const Tx x) const { |
71 | | - return static_cast<Ty>(x); |
72 | | - } |
73 | | -}; |
74 | | - |
75 | | -/** |
76 | | - * @brief Default unary div functor. Divide by a constant |
77 | | - */ |
78 | | -template <typename Tx, typename Ty = Tx> |
79 | | -struct DivideFunctor { |
80 | | - private: |
81 | | - using MPType = typename ::paddle::operators::details::MPTypeTrait<Tx>::Type; |
82 | | - |
83 | | - public: |
84 | | - HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<MPType>(1.0f); } |
85 | | - |
86 | | - HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {} |
87 | | - |
88 | | - HOSTDEVICE inline Ty operator()(const Tx x) const { |
89 | | - return static_cast<Ty>(static_cast<MPType>(x) * n_inv); |
90 | | - } |
91 | | - |
92 | | - private: |
93 | | - MPType n_inv; |
94 | | -}; |
95 | | - |
96 | | -/** |
97 | | - * @brief Default inverse functor |
98 | | - */ |
99 | | -template <typename Tx, typename Ty = Tx> |
100 | | -struct InverseFunctor { |
101 | | - HOSTDEVICE inline InverseFunctor() {} |
102 | | - |
103 | | - HOSTDEVICE explicit inline InverseFunctor(int n) {} |
104 | | - |
105 | | - HOSTDEVICE inline Ty operator()(const Tx x) const { |
106 | | - return static_cast<Ty>(-x); |
107 | | - } |
108 | | -}; |
109 | | - |
110 | | -/** |
111 | | - * @brief Default unary square functor |
112 | | - */ |
113 | | -template <typename Tx, typename Ty = Tx> |
114 | | -struct SquareFunctor { |
115 | | - HOSTDEVICE inline SquareFunctor() {} |
116 | | - |
117 | | - HOSTDEVICE explicit inline SquareFunctor(int n) {} |
118 | | - |
119 | | - HOSTDEVICE inline Ty operator()(const Tx x) const { |
120 | | - return static_cast<Ty>(x) * static_cast<Ty>(x); |
121 | | - } |
122 | | -}; |
123 | | - |
124 | | -/****************************** Binary Functor ********************************/ |
125 | | - |
126 | | -/** |
127 | | - * @brief Default binary min functor |
128 | | - */ |
129 | | -template <typename T> |
130 | | -struct MinFunctor { |
131 | | - inline T initial() { return static_cast<T>(std::numeric_limits<T>::max()); } |
132 | | - |
133 | | - __device__ __forceinline__ T operator()(const T a, const T b) const { |
134 | | - return (b < a) ? b : a; |
135 | | - } |
136 | | -}; |
137 | | - |
138 | | -/** |
139 | | - * @brief Default binary max functor |
140 | | - */ |
141 | | -template <typename T> |
142 | | -struct MaxFunctor { |
143 | | - inline T initial() { |
144 | | - return static_cast<T>(std::numeric_limits<T>::lowest()); |
145 | | - } |
146 | | - |
147 | | - __device__ __forceinline__ T operator()(const T a, const T b) const { |
148 | | - return (b > a) ? b : a; |
149 | | - } |
150 | | -}; |
151 | | - |
152 | | -/** |
153 | | - * @brief Default binary add functor |
154 | | - */ |
155 | | -template <typename T> |
156 | | -struct AddFunctor { |
157 | | - inline T initial() { return static_cast<T>(0.0f); } |
158 | | - |
159 | | - __device__ __forceinline__ T operator()(const T a, const T b) const { |
160 | | - return b + a; |
161 | | - } |
162 | | -}; |
163 | | - |
164 | | -/** |
165 | | - * @brief Default binary add functor |
166 | | - */ |
167 | | -template <typename T> |
168 | | -struct MulFunctor { |
169 | | - inline T initial() { return static_cast<T>(1.0f); } |
170 | | - |
171 | | - __device__ __forceinline__ T operator()(const T a, const T b) const { |
172 | | - return b * a; |
173 | | - } |
174 | | -}; |
175 | | - |
176 | | -/** |
177 | | - * @brief Default binary logic or functor |
178 | | - */ |
179 | | -template <typename T> |
180 | | -struct LogicalOrFunctor { |
181 | | - inline T initial() { return static_cast<T>(false); } |
182 | | - |
183 | | - __device__ __forceinline__ T operator()(const T a, const T b) const { |
184 | | - return b || a; |
185 | | - } |
186 | | -}; |
187 | | - |
188 | | -/** |
189 | | - * @brief Default binary logic and functor |
190 | | - */ |
191 | | -template <typename T> |
192 | | -struct LogicalAndFunctor { |
193 | | - inline T initial() { return static_cast<T>(true); } |
194 | | - |
195 | | - __device__ __forceinline__ T operator()(const T a, const T b) const { |
196 | | - return b && a; |
197 | | - } |
198 | | -}; |
199 | | - |
200 | | -/** |
201 | | - * @brief Default binary sub functor |
202 | | - */ |
203 | | -template <typename T> |
204 | | -struct SubFunctor { |
205 | | - inline T initial() { return static_cast<T>(0.0f); } |
206 | | - |
207 | | - inline HOSTDEVICE T operator()(const T a, const T b) const { return a - b; } |
208 | | -}; |
209 | | - |
210 | | -/** |
211 | | - * @brief Default binary div functor |
212 | | - */ |
213 | | -template <typename T, typename Enable = void> |
214 | | -struct DivFunctor { |
215 | | - inline T initial() { return static_cast<T>(1.0f); } |
216 | | - |
217 | | - inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; } |
218 | | -}; |
219 | | - |
220 | | -template <typename T> |
221 | | -struct DivFunctor<T, |
222 | | - typename std::enable_if<std::is_integral<T>::value>::type> { |
223 | | - inline T initial() { return static_cast<T>(1.0f); } |
224 | | - |
225 | | - inline HOSTDEVICE T operator()(const T a, const T b) const { |
226 | | - // For int32/int64, need to check whether the divison is zero. |
227 | | - PADDLE_ENFORCE_NE(b, 0, |
228 | | - platform::errors::InvalidArgument( |
229 | | - "Integer division by zero encountered " |
230 | | - "in (floor) divide. Please check the input value.")); |
231 | | - return a / b; |
232 | | - } |
233 | | -}; |
234 | | - |
235 | | -/** |
236 | | - * @brief Default binary floor divide functor |
237 | | - */ |
238 | | -template <typename T> |
239 | | -struct FloorDivFunctor { |
240 | | - inline T initial() { return static_cast<T>(1.0f); } |
241 | | - |
242 | | - inline HOSTDEVICE T operator()(const T a, const T b) const { |
243 | | - PADDLE_ENFORCE_NE(b, 0, |
244 | | - platform::errors::InvalidArgument( |
245 | | - "Integer division by zero encountered " |
246 | | - "in (floor) divide. Please check the input value.")); |
247 | | - return static_cast<T>(std::trunc(a / b)); |
248 | | - } |
249 | | -}; |
250 | | - |
251 | | -} // namespace kernel_primitives |
| 20 | +namespace kernel_primitives = pten::kps; |
252 | 21 | } // namespace operators |
253 | 22 | } // namespace paddle |
0 commit comments