Skip to content

Commit 452bcbe

Browse files
authored
[Pten]Move kernel_primitives lib to Pten directory (PaddlePaddle#39169)
* move kernel_primitives * use pten's errors
1 parent bd5c962 commit 452bcbe

File tree

11 files changed

+578
-403
lines changed

11 files changed

+578
-403
lines changed

paddle/fluid/operators/kernel_primitives/functor_primitives.h

Lines changed: 2 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -13,241 +13,10 @@
1313
// limitations under the License.
1414

1515
#pragma once
16-
17-
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
18-
#include "paddle/fluid/platform/enforce.h"
19-
#include "paddle/fluid/platform/float16.h"
20-
#include "paddle/pten/kernels/funcs/eigen/extensions.h"
16+
#include "paddle/pten/kernels/primitive/functor_primitives.h"
2117

2218
namespace paddle {
2319
namespace operators {
24-
namespace kernel_primitives {
25-
namespace details {
26-
27-
static __device__ __forceinline__ platform::float16 Exp(platform::float16 x) {
28-
return ::Eigen::numext::exp(x);
29-
}
30-
31-
static __device__ __forceinline__ float Exp(float x) { return expf(x); }
32-
33-
static __device__ __forceinline__ double Exp(double x) { return exp(x); }
34-
35-
static __device__ __forceinline__ platform::float16 Log(platform::float16 x) {
36-
return ::Eigen::numext::log(x);
37-
}
38-
39-
static __device__ __forceinline__ float Log(float x) { return logf(x); }
40-
41-
static __device__ __forceinline__ double Log(double x) { return log(x); }
42-
43-
} // namespace details
44-
45-
/******************************** Unary Functor *******************************/
46-
47-
/**
48-
* @brief Default unary exp functor
49-
*/
50-
template <typename Tx, typename Ty = Tx>
51-
struct ExpFunctor {
52-
HOSTDEVICE inline ExpFunctor() {}
53-
54-
HOSTDEVICE explicit inline ExpFunctor(int n) {}
55-
56-
HOSTDEVICE inline Ty operator()(const Tx x) const {
57-
return static_cast<Ty>(details::Exp(x));
58-
}
59-
};
60-
61-
/**
62-
* @brief Default unary identity functor
63-
*/
64-
template <typename Tx, typename Ty = Tx>
65-
struct IdentityFunctor {
66-
HOSTDEVICE inline IdentityFunctor() {}
67-
68-
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
69-
70-
HOSTDEVICE inline Ty operator()(const Tx x) const {
71-
return static_cast<Ty>(x);
72-
}
73-
};
74-
75-
/**
76-
* @brief Default unary div functor. Divide by a constant
77-
*/
78-
template <typename Tx, typename Ty = Tx>
79-
struct DivideFunctor {
80-
private:
81-
using MPType = typename ::paddle::operators::details::MPTypeTrait<Tx>::Type;
82-
83-
public:
84-
HOSTDEVICE inline DivideFunctor() { n_inv = static_cast<MPType>(1.0f); }
85-
86-
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {}
87-
88-
HOSTDEVICE inline Ty operator()(const Tx x) const {
89-
return static_cast<Ty>(static_cast<MPType>(x) * n_inv);
90-
}
91-
92-
private:
93-
MPType n_inv;
94-
};
95-
96-
/**
97-
* @brief Default inverse functor
98-
*/
99-
template <typename Tx, typename Ty = Tx>
100-
struct InverseFunctor {
101-
HOSTDEVICE inline InverseFunctor() {}
102-
103-
HOSTDEVICE explicit inline InverseFunctor(int n) {}
104-
105-
HOSTDEVICE inline Ty operator()(const Tx x) const {
106-
return static_cast<Ty>(-x);
107-
}
108-
};
109-
110-
/**
111-
* @brief Default unary square functor
112-
*/
113-
template <typename Tx, typename Ty = Tx>
114-
struct SquareFunctor {
115-
HOSTDEVICE inline SquareFunctor() {}
116-
117-
HOSTDEVICE explicit inline SquareFunctor(int n) {}
118-
119-
HOSTDEVICE inline Ty operator()(const Tx x) const {
120-
return static_cast<Ty>(x) * static_cast<Ty>(x);
121-
}
122-
};
123-
124-
/****************************** Binary Functor ********************************/
125-
126-
/**
127-
* @brief Default binary min functor
128-
*/
129-
template <typename T>
130-
struct MinFunctor {
131-
inline T initial() { return static_cast<T>(std::numeric_limits<T>::max()); }
132-
133-
__device__ __forceinline__ T operator()(const T a, const T b) const {
134-
return (b < a) ? b : a;
135-
}
136-
};
137-
138-
/**
139-
* @brief Default binary max functor
140-
*/
141-
template <typename T>
142-
struct MaxFunctor {
143-
inline T initial() {
144-
return static_cast<T>(std::numeric_limits<T>::lowest());
145-
}
146-
147-
__device__ __forceinline__ T operator()(const T a, const T b) const {
148-
return (b > a) ? b : a;
149-
}
150-
};
151-
152-
/**
153-
* @brief Default binary add functor
154-
*/
155-
template <typename T>
156-
struct AddFunctor {
157-
inline T initial() { return static_cast<T>(0.0f); }
158-
159-
__device__ __forceinline__ T operator()(const T a, const T b) const {
160-
return b + a;
161-
}
162-
};
163-
164-
/**
165-
* @brief Default binary add functor
166-
*/
167-
template <typename T>
168-
struct MulFunctor {
169-
inline T initial() { return static_cast<T>(1.0f); }
170-
171-
__device__ __forceinline__ T operator()(const T a, const T b) const {
172-
return b * a;
173-
}
174-
};
175-
176-
/**
177-
* @brief Default binary logic or functor
178-
*/
179-
template <typename T>
180-
struct LogicalOrFunctor {
181-
inline T initial() { return static_cast<T>(false); }
182-
183-
__device__ __forceinline__ T operator()(const T a, const T b) const {
184-
return b || a;
185-
}
186-
};
187-
188-
/**
189-
* @brief Default binary logic and functor
190-
*/
191-
template <typename T>
192-
struct LogicalAndFunctor {
193-
inline T initial() { return static_cast<T>(true); }
194-
195-
__device__ __forceinline__ T operator()(const T a, const T b) const {
196-
return b && a;
197-
}
198-
};
199-
200-
/**
201-
* @brief Default binary sub functor
202-
*/
203-
template <typename T>
204-
struct SubFunctor {
205-
inline T initial() { return static_cast<T>(0.0f); }
206-
207-
inline HOSTDEVICE T operator()(const T a, const T b) const { return a - b; }
208-
};
209-
210-
/**
211-
* @brief Default binary div functor
212-
*/
213-
template <typename T, typename Enable = void>
214-
struct DivFunctor {
215-
inline T initial() { return static_cast<T>(1.0f); }
216-
217-
inline HOSTDEVICE T operator()(const T a, const T b) const { return a / b; }
218-
};
219-
220-
template <typename T>
221-
struct DivFunctor<T,
222-
typename std::enable_if<std::is_integral<T>::value>::type> {
223-
inline T initial() { return static_cast<T>(1.0f); }
224-
225-
inline HOSTDEVICE T operator()(const T a, const T b) const {
226-
// For int32/int64, need to check whether the divison is zero.
227-
PADDLE_ENFORCE_NE(b, 0,
228-
platform::errors::InvalidArgument(
229-
"Integer division by zero encountered "
230-
"in (floor) divide. Please check the input value."));
231-
return a / b;
232-
}
233-
};
234-
235-
/**
236-
* @brief Default binary floor divide functor
237-
*/
238-
template <typename T>
239-
struct FloorDivFunctor {
240-
inline T initial() { return static_cast<T>(1.0f); }
241-
242-
inline HOSTDEVICE T operator()(const T a, const T b) const {
243-
PADDLE_ENFORCE_NE(b, 0,
244-
platform::errors::InvalidArgument(
245-
"Integer division by zero encountered "
246-
"in (floor) divide. Please check the input value."));
247-
return static_cast<T>(std::trunc(a / b));
248-
}
249-
};
250-
251-
} // namespace kernel_primitives
20+
namespace kernel_primitives = pten::kps;
25221
} // namespace operators
25322
} // namespace paddle

paddle/fluid/operators/kernel_primitives/kernel_primitives.h

Lines changed: 2 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,61 +13,10 @@
1313
// limitations under the License.
1414

1515
#pragma once
16-
#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
17-
#ifdef PADDLE_WITH_XPU2
18-
#include "paddle/fluid/operators/kernel_primitives/compute_primitives_xpu2.h"
19-
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives_xpu2.h"
20-
#include "paddle/fluid/operators/kernel_primitives/functor_primitives_xpu2.h"
21-
22-
#define KPStream XPUStream
23-
#define KPDevice paddle::platform::XPUDeviceContext
24-
#define _ptr_ _global_ptr_
25-
#define __forceinline__ __inline__
26-
#define __restrict__
27-
28-
#define THREAD_ID_X core_id()
29-
#define THREAD_ID_Y 0
30-
#define THREAD_ID_Z 0
31-
32-
#define BLOCK_NUM_X core_num()
33-
#define BLOCK_NUM_Y 0
34-
#define BLOCK_NUM_Z 0
35-
36-
#define BLOCK_ID_X cluster_id()
37-
#define BLOCK_ID_Y 0
38-
#define BLOCK_ID_Z 0
39-
40-
#define GRID_NUM_X cluster_num()
41-
#define GRID_NUM_Y 0
42-
#define GRID_NUM_Z 0
43-
#else
44-
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
45-
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
46-
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
47-
48-
#define KPStream gpuStream_t
49-
#define KPDevice paddle::platform::CUDADeviceContext
50-
#define _ptr_
51-
52-
#define THREAD_ID_X threadIdx.x
53-
#define THREAD_ID_Y threadIdx.y
54-
#define THREAD_ID_Z threadIdx.z
55-
56-
#define BLOCK_NUM_X blockDim.x
57-
#define BLOCK_NUM_Y blockDim.y
58-
#define BLOCK_NUM_Z blockDim.z
59-
60-
#define BLOCK_ID_X blockIdx.x
61-
#define BLOCK_ID_Y blockIdx.y
62-
#define BLOCK_ID_Z blockIdx.z
63-
64-
#define GRID_NUM_X gridDim.x
65-
#define GRID_NUM_Y gridDim.y
66-
#define GRID_NUM_Z gridDim.z
67-
#endif
16+
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
6817

6918
namespace paddle {
7019
namespace operators {
71-
namespace kernel_primitives {}
20+
namespace kernel_primitives = pten::kps;
7221
}
7322
}

paddle/pten/kernels/funcs/elementwise_base.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ limitations under the License. */
2222
#include "paddle/pten/kernels/empty_kernel.h"
2323

2424
#if defined(__NVCC__) || defined(__HIPCC__)
25-
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
2625
#include "paddle/fluid/platform/aligned_vector.h"
2726
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
2827
#include "paddle/fluid/platform/function_traits.h"
28+
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
2929

30-
namespace kps = paddle::operators::kernel_primitives;
30+
namespace kps = pten::kps;
3131

3232
#endif
3333

paddle/pten/kernels/gpu/reduce.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ namespace cub = hipcub;
3434

3535
#include "paddle/fluid/framework/op_registry.h"
3636
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
37-
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
3837
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
3938
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
4039
#include "paddle/fluid/platform/fast_divmod.h"
4140
#include "paddle/fluid/string/string_helper.h"
4241
#include "paddle/pten/core/array.h"
4342
#include "paddle/pten/core/enforce.h"
43+
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
4444

4545
#include "paddle/pten/api/ext/dispatch.h"
4646
#include "paddle/pten/backends/gpu/gpu_context.h"
@@ -51,7 +51,7 @@ namespace cub = hipcub;
5151
#define REDUCE_SPLIT_BOUNDARY 512
5252
#define REDUCE_VEC_SIZE 4
5353

54-
namespace kps = paddle::operators::kernel_primitives;
54+
namespace kps = pten::kps;
5555

5656
namespace pten {
5757
namespace kernels {

0 commit comments

Comments
 (0)