@@ -25,14 +25,12 @@ namespace operators {
2525using Tensor = framework::Tensor;
2626
// Counts how many entries of `segment_ids` carry each segment id and stores
// the counts, as values of type T, in `summed_ids`.
//
// Each iteration of CUDA_KERNEL_LOOP handles one stripe: up to DimTileSize
// consecutive entries of `segment_ids`, starting at
// stripe_index * DimTileSize.
//
// segment_ids:        segment id of every input row; the in-kernel check
//                     requires the ids to be non-decreasing.
// summed_ids:         output, indexed by segment id. Counts for the first and
//                     the last segment of a stripe are accumulated with
//                     atomic adds, so those slots are presumably expected to
//                     be zero on entry -- TODO confirm with the caller.
// input_length_size:  number of entries in `segment_ids`.
// total_stripe_count: number of stripes to process.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
                                    const Index input_length_size,
                                    const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    const Index dim_index_base = stripe_index * Index(DimTileSize);
    // The final stripe may hold fewer than DimTileSize entries.
    const Index actual_height =
        min(Index(DimTileSize), input_length_size - dim_index_base);

    // NOTE(review): the declarations of first_segment_id / last_segment_id
    // were not visible in the reviewed excerpt; the initial values below
    // mirror the prologue of SegmentMeanKernel -- confirm against the full
    // file.
    Index first_segment_id = segment_ids[dim_index_base];
    Index last_segment_id = -1;
    if (dim_index_base > 0) {
      // Continue from the id that ended the previous stripe so that a segment
      // crossing the stripe boundary is not treated as a new one.
      last_segment_id = segment_ids[dim_index_base - 1];
    }

    T sum = T(0);
    for (Index j = 0; j < actual_height; j++) {
      Index current_segment_id = segment_ids[dim_index_base + j];
      // Arguments follow the order of the format string: index, id, index,
      // id. They are widened to long long because Index may be 64-bit.
      // NOTE(review): an earlier revision kept this check commented out with
      // a remark that it may cause cudaErrorLaunchOutOfResources -- confirm
      // this is no longer a concern.
      PADDLE_ENFORCE(current_segment_id >= last_segment_id,
                     "the segment ids should be sorted, but got "
                     "segment_ids[%lld]:%lld > segment_ids[%lld]:%lld.",
                     static_cast<long long>(dim_index_base + j - 1),
                     static_cast<long long>(last_segment_id),
                     static_cast<long long>(dim_index_base + j),
                     static_cast<long long>(current_segment_id));
      if (current_segment_id > last_segment_id) {
        // Ids skipped between two consecutive entries own no rows: zero them.
        for (Index interval_id = last_segment_id + 1;
             interval_id < current_segment_id; ++interval_id) {
          *(summed_ids + interval_id) = T(0);
        }
        // Nothing has been counted yet when j == 0, so there is nothing to
        // flush.
        if (j > 0) {
          if (last_segment_id == first_segment_id) {
            // The first segment of a stripe may have started in the previous
            // stripe, so its count is accumulated atomically.
            platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
          } else {
            // A segment that starts and ends inside this stripe is written
            // directly.
            *(summed_ids + last_segment_id) = sum;
          }
          sum = T(0);
        }
      }
      sum += T(1);
      last_segment_id = current_segment_id;
    }
    // The last segment of the stripe may continue in the next stripe.
    platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
  }
}
70+
// Accumulates, for every segment id, the sum of the matching rows of `input`
// divided by `summed_ids[id]`, and stores the result in `output`.
//
// `input` and `output` are addressed as row-major 2-D buffers with
// `inner_dim_size` elements per row. Each iteration of CUDA_KERNEL_LOOP
// handles one stripe: a single column and up to DimTileSize consecutive rows.
//
// segment_ids:        segment id of every input row.
// input:              source values, read at [row * inner_dim_size + column].
// output:             result, written at [id * inner_dim_size + column].
//                     Partial results are added atomically, so the buffer is
//                     presumably zero-initialized by the caller -- TODO
//                     confirm.
// summed_ids:         divisor for each segment id (presumably the row count
//                     produced by SegmentSumIdsKernel -- confirm).
// input_length_size:  number of input rows.
// inner_dim_size:     number of elements per row.
// output_length_size: not used by the kernel body.
// total_stripe_count: number of stripes to process.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
                                  T* output, T* summed_ids,
                                  const Index input_length_size,
                                  const Index inner_dim_size,
                                  const Index output_length_size,
                                  const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    const Index column = stripe_index % inner_dim_size;
    const Index row_begin = stripe_index / inner_dim_size * Index(DimTileSize);
    // The final stripe of a column may hold fewer than DimTileSize rows.
    const Index row_count =
        min(Index(DimTileSize), input_length_size - row_begin);

    const Index stripe_first_id = segment_ids[row_begin];
    // Id that ended the previous stripe, or -1 for the very first stripe.
    Index prev_id = (row_begin > 0) ? segment_ids[row_begin - 1] : Index(-1);

    T partial = T(0);
    for (Index r = 0; r < row_count; ++r) {
      const Index row = row_begin + r;
      const Index cur_id = segment_ids[row];
      if (cur_id > prev_id) {
        // Ids skipped between two consecutive rows own no input: zero them.
        for (Index gap_id = prev_id + 1; gap_id < cur_id; ++gap_id) {
          output[gap_id * inner_dim_size + column] = T(0);
        }

        // Nothing has been accumulated yet when r == 0.
        if (r > 0) {
          const Index out_pos = prev_id * inner_dim_size + column;
          const T mean_part = partial / summed_ids[prev_id];
          if (prev_id == stripe_first_id) {
            // The first segment of a stripe may have started in the previous
            // stripe, so its contribution is added atomically.
            platform::CudaAtomicAdd(output + out_pos, mean_part);
          } else {
            // Segment starts and ends inside this stripe: direct store.
            output[out_pos] = mean_part;
          }
          partial = T(0);
        }
      }
      partial += input[row * inner_dim_size + column];
      prev_id = cur_id;
    }

    // The last segment of the stripe may continue in the next stripe.
    const Index tail_pos = prev_id * inner_dim_size + column;
    platform::CudaAtomicAdd(output + tail_pos, partial / summed_ids[prev_id]);
  }
}
@@ -122,7 +148,7 @@ __global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
122148 // reset the interval value which do not have corresponding ids.
123149 for (Index interval_id = last_segment_id + 1 ;
124150 interval_id < current_segment_id; ++interval_id) {
125- *(output + interval_id * inner_dim_size + segment_offset) = 0 ;
151+ *(output + interval_id * inner_dim_size + segment_offset) = T ( 0 ) ;
126152 }
127153 // don't update result when j=0
128154 if (j > 0 ) {
@@ -272,11 +298,25 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
272298 framework::Tensor* output,
273299 framework::Tensor* summed_ids = nullptr ,
274300 const std::string pooltype = " SUM" ) {
301+ if (pooltype == " MEAN" ) {
302+ // Sum the segment id num first
303+ T DimTileSize = 8 ;
304+ auto input_length_size = segment_ids.numel ();
305+ auto total_stripe_count =
306+ (input_length_size + DimTileSize - 1 ) / DimTileSize;
307+ auto config = platform::GetGpuLaunchConfig1D (ctx, total_stripe_count);
308+ SegmentSumIdsKernel<
309+ T, IndexT, IndexT (8 )><<<config.block_per_grid.x,
310+ config.thread_per_block.x, 0 , ctx.stream()>>> (
311+ segment_ids.data <IndexT>(), summed_ids->data <T>(), input_length_size,
312+ total_stripe_count);
313+ }
314+
275315 auto h = ArrangeHelper<IndexT>(input.numel (), segment_ids.dims ()[0 ],
276316 output->dims ()[0 ]);
277317 auto config = platform::GetGpuLaunchConfig1D (ctx, h.total_stripe_count );
278318 if (pooltype == " MEAN" ) {
279- SegmentMeanCustomKernel <
319+ SegmentMeanKernel <
280320 T, IndexT, IndexT (8 )><<<config.block_per_grid.x,
281321 config.thread_per_block.x, 0 , ctx.stream()>>> (
282322 segment_ids.data <IndexT>(), input.data <T>(), output->data <T>(),
0 commit comments