@@ -54,8 +54,8 @@ class MPTypeTrait<platform::float16> {
 };
 
 /**
- * @brief will be used in BlockYReduce, get the index of reduce_num in shared
- * memory
+ * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
+ * memory.
  */
 __device__ __forceinline__ int SharedMemoryIndex(int index) {
   return (threadIdx.y + index) * blockDim.x + threadIdx.x;
@@ -83,7 +83,7 @@ __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
  */
 
 /**
- * @brief BlockXReduce reduce along blockDim.x
+ * @brief BlockXReduce reduce along blockDim.x.
  */
 template <typename T, typename ReduceOp>
 __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
@@ -115,7 +115,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
 }
 
 /**
- * @brief BlockYReduce reduce along blockDim.y
+ * @brief BlockYReduce reduce along blockDim.y.
  */
 template <typename T, typename ReduceOp>
 __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
@@ -135,24 +135,33 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
 }  // namespace details
 
 /**
- * @brief unary function
- * @param
- * T: data type of in
- * OutT: data type of out
- * NX: the cols of in
- * NY: the rows of in
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *         template <typename T, typename OutT>
+ * @brief Perform unary calculation according to OpFunc. Sizes of the input
+ * and output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently, only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *         template <typename InT, typename OutT>
  *         struct XxxFunctor {
- *           HOSTDEVICE OutT operator()(const T& a) const {
+ *           HOSTDEVICE OutT operator()(const InT& a) const {
  *             return ...;
  *           }
  *         };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in: The register pointer of in, the size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
+__device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
                                                  OpFunc compute) {
 #pragma unroll
   for (int idx = 0; idx < NX * NY; idx++) {
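A minimal usage sketch of the revised ElementwiseUnary signature, not taken from this PR: a hypothetical ScaleFunctor applied to register-resident data. The functor name, the NX/NY values, and the BlockSize of 1 are assumptions, and the helper is assumed to live in the same namespace as the primitive.

// Hypothetical unary functor; not part of this diff.
template <typename InT, typename OutT>
struct ScaleFunctor {
  __device__ OutT operator()(const InT& a) const {
    return static_cast<OutT>(a) * static_cast<OutT>(2);
  }
};

// Device-side sketch: doubles NX * NY register-resident floats.
template <int NX, int NY>
__device__ void ScaleRegisters(float* out, const float* in) {
  ElementwiseUnary<float, float, NX, NY, 1, ScaleFunctor<float, float>>(
      out, in, ScaleFunctor<float, float>());
}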
@@ -161,25 +170,35 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
 }
 
 /**
- * @brief binary function, in1 and in2 have same shape
- * @param
- * T: data type of in1, in2
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *         template <typename T, typename OutT>
+ * @brief Binary calculation according to OpFunc. Sizes of the input and
+ * output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in1 and in2.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently, only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *         template <typename InT, typename OutT>
  *         struct XxxFunctor {
- *           HOSTDEVICE OutT operator()(const T& a, const T& b) const {
+ *           HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
  *             return ...;
  *           }
  *         };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of first input, size is NX * NY.
+ * in2: The register pointer of second input, size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
-                                                  const T* in2,
+__device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
+                                                  const InT* in2,
                                                   OpFunc compute) {
 #pragma unroll
   for (int idx = 0; idx < NX * NY; ++idx) {
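A similar sketch for the binary form, again assumed rather than taken from the PR: a hypothetical AddFunctor passed to ElementwiseBinary, with NX, NY and BlockSize = 1 as placeholder values.

// Hypothetical binary functor; not part of this diff.
template <typename InT, typename OutT>
struct AddFunctor {
  __device__ OutT operator()(const InT& a, const InT& b) const {
    return static_cast<OutT>(a + b);
  }
};

// Device-side sketch: out[i] = in1[i] + in2[i] for NX * NY register elements.
template <int NX, int NY>
__device__ void AddRegisters(float* out, const float* in1, const float* in2) {
  ElementwiseBinary<float, float, NX, NY, 1, AddFunctor<float, float>>(
      out, in1, in2, AddFunctor<float, float>());
}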
@@ -188,25 +207,38 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
 }
 
 /**
- * @brief ternary function, in1, in2 and in3 have same shape
- * @param
- * T: data type of in1, in2, in3
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *         template <typename T, typename OutT>
+ * @brief Ternary calculation according to OpFunc. Sizes of the input and
+ * output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in1, in2 and in3.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently, only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *         template <typename InT, typename OutT>
  *         struct XxxFunctor {
- *           HOSTDEVICE OutT operator()(const T& a, const T& b, const T& c) const {
+ *           HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c)
+ *               const {
  *             return ...;
  *           }
  *         };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of first input, size is NX * NY.
+ * in2: The register pointer of second input, size is NX * NY.
+ * in3: The register pointer of third input, size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
-                                                   const T* in2, const T* in3,
+__device__ __forceinline__ void ElementwiseTernary(OutT* out, const InT* in1,
+                                                   const InT* in2,
+                                                   const InT* in3,
                                                    OpFunc compute) {
 #pragma unroll
   for (int idx = 0; idx < NX * NY; ++idx) {
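For the ternary variant, a fused multiply-add functor is a natural illustration. The FmaFunctor name and the launch parameters below are assumptions, not part of this diff.

// Hypothetical ternary functor computing a * b + c; not part of this diff.
template <typename InT, typename OutT>
struct FmaFunctor {
  __device__ OutT operator()(const InT& a, const InT& b, const InT& c) const {
    return static_cast<OutT>(a * b + c);
  }
};

// Device-side sketch: out[i] = in1[i] * in2[i] + in3[i].
template <int NX, int NY>
__device__ void FmaRegisters(float* out, const float* in1, const float* in2,
                             const float* in3) {
  ElementwiseTernary<float, float, NX, NY, 1, FmaFunctor<float, float>>(
      out, in1, in2, in3, FmaFunctor<float, float>());
}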
@@ -215,27 +247,36 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
 }
 
 /**
- * @brief a general function for elementwise computation, all inputs have
- * the same shape.
- * @param
- * T: data type of in1, in2, in3
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- *         template <typename T, typename OutT>
+ * @brief Multivariate calculation according to OpFunc. Sizes of the inputs
+ * and output are the same.
+ *
+ * @template parameters
+ * InT: Data type of the inputs in ins.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently, only GPU is supported.
+ * Arity: The size of ins.
+ * OpFunc: Compute functor which has an operator() as following:
+ *         template <typename InT, typename OutT>
  *         struct XxxFunctor {
- *           HOSTDEVICE OutT operator()(const T* args) const {
+ *           HOSTDEVICE OutT operator()(const InT* args) const {
  *             return ...;
  *           }
  *         };
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * ins: An array of pointers consisting of multiple inputs.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize, int Arity,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity,
           class OpFunc>
-__device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
+__device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY],
                                                OpFunc compute) {
-  T args[Arity];
+  InT args[Arity];
 #pragma unroll
   for (int idx = 0; idx < NX * NY; ++idx) {
 #pragma unroll
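ElementwiseAny gathers Arity values into an args array before each functor call, so the functor only sees a pointer. A sketch with a hypothetical SumArgsFunctor and Arity = 3 follows; every name and value here is an assumption for illustration.

// Hypothetical functor summing Arity = 3 inputs; not part of this diff.
template <typename InT, typename OutT>
struct SumArgsFunctor {
  __device__ OutT operator()(const InT* args) const {
    return static_cast<OutT>(args[0] + args[1] + args[2]);
  }
};

// Device-side sketch: ins holds 3 register buffers of NX * NY elements each.
template <int NX, int NY>
__device__ void SumThreeRegisters(float* out, float (*ins)[NX * NY]) {
  ElementwiseAny<float, float, NX, NY, 1, 3, SumArgsFunctor<float, float>>(
      out, ins, SumArgsFunctor<float, float>());
}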
@@ -247,20 +288,36 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
 }
 
 /**
- * @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
- * is [NY, NX], out's shape size is [NY, NX]
+ * @brief Binary calculation according to OpFunc. The shapes of in1 and in2
+ * are different: in1's shape is [1, NX], in2's shape is [NY, NX], and the
+ * output shape is [NY, NX].
+ *
+ * @template parameters
+ * InT: Data type of in1 and in2.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently, only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ *         template <typename InT, typename OutT>
+ *         struct XxxFunctor {
+ *           HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
+ *             return ...;
+ *           }
+ *         };
+ *
  * @param
- * T: data type of in1, in2
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor eg: in1 + in2, in1 - in2
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of first input, size is NX * 1.
+ * in2: The register pointer of second input, size is NX * NY.
+ * compute: Compute function which is declared like OpFunc<InT, OutT>().
  */
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
-__device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
-                                            const T* in2, OpFunc compute) {
+__device__ __forceinline__ void CycleBinary(OutT* out, const InT* in1,
+                                            const InT* in2, OpFunc compute) {
 #pragma unroll
   for (int idx = 0; idx < NX; idx++) {
 #pragma unroll
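CycleBinary broadcasts in1 (length NX) against in2 (length NX * NY). A hypothetical bias-add sketch, reusing the AddFunctor assumed in the earlier binary example:

// Device-side sketch: adds an NX-element bias row (in1) to each of the NY
// rows held in in2. AddFunctor is the hypothetical functor sketched above.
template <int NX, int NY>
__device__ void BiasAddRegisters(float* out, const float* bias,
                                 const float* in2) {
  CycleBinary<float, float, NX, NY, 1, AddFunctor<float, float>>(
      out, bias, in2, AddFunctor<float, float>());
}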
@@ -272,26 +329,37 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
 }
 
 /**
- * @brief reduce function, in's shape size is [NX, NY].
- * If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
- * if ReduceMode == kGlobalMode then reduce between different threads, the
- * shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
- * split, BlockYReduce will be called. If reduce_last_dim is true and
- * reduce_num was split, BlockXReduce will be called
- * @typename
- * T: data type of in
- * NX: the cols of in
- * NY: the rows of in
- * BlockSize: the config of this device
- * OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h
- * @param:
- * reducer: reduce functor, eg: CustomSum<T>()
- * reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim =
- * true
+ * @brief Reduce provides collective methods for computing a parallel
+ * reduction of items partitioned across a CUDA block and within each thread.
+ * When ReduceMode == kLocalMode, each thread reduces along nx; when
+ * ReduceMode == kGlobalMode, shared memory is used to reduce between threads.
+ *
+ * @template parameters
+ * T: The type of data.
+ * NX: The number of data elements continuously loaded by each thread.
+ * NY: The number of data rows loaded by each thread, only NY = 1 is supported.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently, only GPU is supported.
+ * ReduceFunctor: Compute functor which has an operator() as following:
+ *         template <typename InT>
+ *         struct ReduceFunctor {
+ *           HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
+ *             return ...;
+ *           }
+ *         };
+ * ReduceMode: Reduce mode, can be kLocalMode or kGlobalMode.
+ *
+ * @param
+ * out: The register pointer of out, the size is NX * NY.
+ * in: The register pointer of in, the size is NX * NY.
+ * reducer: Compute function which is declared like ReduceFunctor<InT>().
+ * reduce_last_dim: Whether the last dim is involved in the reduction.
  */
-template <typename T, int NX, int NY, int BlockSize, class OpFunc,
+template <typename T, int NX, int NY, int BlockSize, class ReduceFunctor,
           details::ReduceMode Mode>
-__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
+__device__ __forceinline__ void Reduce(T* out, const T* in,
+                                       ReduceFunctor reducer,
                                        bool reduce_last_dim) {
   int block_index = blockDim.y;
 
@@ -302,15 +370,15 @@ __device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
     if (block_reduce_y) {
 #pragma unroll
       for (int i = 0; i < NY * NX; i++) {  // reduce along blockdim.y
-        out[i] = details::BlockYReduce<T, OpFunc>(out[i], reducer);
+        out[i] = details::BlockYReduce<T, ReduceFunctor>(out[i], reducer);
       }
     }
 
     // when last dimension need to be reduced
     if (reduce_last_dim) {
 #pragma unroll
       for (int i = 0; i < NY * NX; i++) {  // reduce along blockDim.x
-        out[i] = details::BlockXReduce<T, OpFunc>(out[i], reducer);
+        out[i] = details::BlockXReduce<T, ReduceFunctor>(out[i], reducer);
       }
     }
   } else {  // else kLocalMode
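To round out the renamed ReduceFunctor parameter, a thread-local max reduction sketch based on the signature shown in this hunk. MaxFunctor, the BlockSize of 1, the assumption that out is pre-initialized by the caller, and the exact spelling of the kLocalMode enumerator are all assumptions, not part of this diff.

// Hypothetical reduce functor; not part of this diff.
template <typename T>
struct MaxFunctor {
  __device__ T operator()(const T& a, const T& b) const {
    return a > b ? a : b;
  }
};

// Device-side sketch: per-thread (kLocalMode) max over NX contiguous values.
// The enumerator is assumed to be reachable as details::ReduceMode::kLocalMode,
// matching the details::ReduceMode template parameter above; out is assumed to
// be initialized by the caller before the call.
template <int NX>
__device__ void LocalMax(float* out, const float* in) {
  Reduce<float, NX, 1, 1, MaxFunctor<float>, details::ReduceMode::kLocalMode>(
      out, in, MaxFunctor<float>(), false);
}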