Skip to content

Commit 9d8de1a

Browse files
committed
[OPs] Bug fix, fix the segment mean for illegal syncthreads usage. (PaddlePaddle#32596)
* [OPs] Bug fix, fix the segment mean for illegal syncthreads usage.
1 parent 1515892 commit 9d8de1a

File tree

1 file changed

+78
-38
lines changed

1 file changed

+78
-38
lines changed

paddle/fluid/operators/math/segment_pooling.cu

Lines changed: 78 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,12 @@ namespace operators {
2525
using Tensor = framework::Tensor;
2626

2727
template <typename T, typename Index, int DimTileSize>
28-
__global__ void SegmentMeanCustomKernel(
29-
const Index* segment_ids, const T* input, T* output, T* summed_ids,
30-
const Index input_length_size, const Index inner_dim_size,
31-
const Index output_length_size, const Index total_stripe_count) {
28+
__global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
29+
const Index input_length_size,
30+
const Index total_stripe_count) {
3231
CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
33-
const Index segment_offset = stripe_index % inner_dim_size;
34-
const Index dim_index_base =
35-
stripe_index / inner_dim_size * Index(DimTileSize);
32+
const Index segment_offset = stripe_index;
33+
const Index dim_index_base = stripe_index * Index(DimTileSize);
3634
const Index actual_height =
3735
min(Index(DimTileSize), input_length_size - dim_index_base);
3836

@@ -41,53 +39,81 @@ __global__ void SegmentMeanCustomKernel(
4139
if (dim_index_base > 0) {
4240
last_segment_id = segment_ids[dim_index_base - 1];
4341
}
44-
if (segment_offset == 0) {
45-
T sum = T(0);
46-
for (Index j = 0; j < actual_height; j++) {
47-
Index current_segment_id = segment_ids[dim_index_base + j];
48-
// Note(ZHUI): following check may cause
49-
// cudaErrorLaunchOutOfResources.
50-
// PADDLE_ENFORCE(current_segment_id >= last_segment_id,
51-
// "the segment ids should be sorted, but got "
52-
// "segment_ids[%d]:%d > segment_ids[%d]:%d.",
53-
// dim_index_base + j - 1, dim_index_base + j,
54-
// last_segment_id, current_segment_id);
55-
56-
if (j > 0 && current_segment_id > last_segment_id) {
42+
T sum = T(0);
43+
for (Index j = 0; j < actual_height; j++) {
44+
Index current_segment_id = segment_ids[dim_index_base + j];
45+
PADDLE_ENFORCE(current_segment_id >= last_segment_id,
46+
"the segment ids should be sorted, but got "
47+
"segment_ids[%d]:%d > segment_ids[%d]:%d.",
48+
dim_index_base + j - 1, dim_index_base + j,
49+
last_segment_id, current_segment_id);
50+
if (current_segment_id > last_segment_id) {
51+
for (Index interval_id = last_segment_id + 1;
52+
interval_id < current_segment_id; ++interval_id) {
53+
*(summed_ids + interval_id) = 0;
54+
}
55+
if (j > 0) {
5756
if (last_segment_id == first_segment_id) {
5857
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
5958
} else {
6059
*(summed_ids + last_segment_id) = sum;
6160
}
6261
sum = T(0);
6362
}
64-
sum += T(1);
65-
last_segment_id = current_segment_id;
6663
}
67-
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
64+
sum += T(1);
65+
last_segment_id = current_segment_id;
66+
}
67+
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
68+
}
69+
}
70+
71+
template <typename T, typename Index, int DimTileSize>
72+
__global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
73+
T* output, T* summed_ids,
74+
const Index input_length_size,
75+
const Index inner_dim_size,
76+
const Index output_length_size,
77+
const Index total_stripe_count) {
78+
CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
79+
const Index segment_offset = stripe_index % inner_dim_size;
80+
const Index dim_index_base =
81+
stripe_index / inner_dim_size * Index(DimTileSize);
82+
const Index actual_height =
83+
min(Index(DimTileSize), input_length_size - dim_index_base);
84+
85+
Index first_segment_id = segment_ids[dim_index_base];
86+
Index last_segment_id = -1;
87+
if (dim_index_base > 0) {
88+
last_segment_id = segment_ids[dim_index_base - 1];
6889
}
69-
// ensure last_segment_id is the largest
70-
last_segment_id = output_length_size;
71-
__syncthreads();
7290
T sum = T(0);
7391
for (Index j = 0; j < actual_height; j++) {
7492
Index current_segment_id = segment_ids[dim_index_base + j];
7593
if (current_segment_id > last_segment_id) {
76-
const Index output_index =
77-
last_segment_id * inner_dim_size + segment_offset;
78-
if (last_segment_id == first_segment_id) {
79-
platform::CudaAtomicAdd(output + output_index,
80-
sum / *(summed_ids + last_segment_id));
81-
} else {
82-
*(output + output_index) = sum / *(summed_ids + last_segment_id);
94+
// reset the interval value which do not have corresponding ids.
95+
for (Index interval_id = last_segment_id + 1;
96+
interval_id < current_segment_id; ++interval_id) {
97+
*(output + interval_id * inner_dim_size + segment_offset) = T(0);
98+
}
99+
100+
if (j > 0) {
101+
Index output_index =
102+
last_segment_id * inner_dim_size + segment_offset;
103+
104+
if (last_segment_id == first_segment_id) {
105+
platform::CudaAtomicAdd(output + output_index,
106+
sum / *(summed_ids + last_segment_id));
107+
} else {
108+
*(output + output_index) = sum / *(summed_ids + last_segment_id);
109+
}
110+
sum = T(0);
83111
}
84-
sum = T(0);
85112
}
86113
sum += input[(dim_index_base + j) * inner_dim_size + segment_offset];
87114
last_segment_id = current_segment_id;
88115
}
89-
const Index output_index =
90-
last_segment_id * inner_dim_size + segment_offset;
116+
Index output_index = last_segment_id * inner_dim_size + segment_offset;
91117
platform::CudaAtomicAdd(output + output_index,
92118
sum / *(summed_ids + last_segment_id));
93119
}
@@ -122,7 +148,7 @@ __global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
122148
// reset the interval value which do not have corresponding ids.
123149
for (Index interval_id = last_segment_id + 1;
124150
interval_id < current_segment_id; ++interval_id) {
125-
*(output + interval_id * inner_dim_size + segment_offset) = 0;
151+
*(output + interval_id * inner_dim_size + segment_offset) = T(0);
126152
}
127153
// don't update result when j=0
128154
if (j > 0) {
@@ -272,11 +298,25 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
272298
framework::Tensor* output,
273299
framework::Tensor* summed_ids = nullptr,
274300
const std::string pooltype = "SUM") {
301+
if (pooltype == "MEAN") {
302+
// Sum the segment id num first
303+
T DimTileSize = 8;
304+
auto input_length_size = segment_ids.numel();
305+
auto total_stripe_count =
306+
(input_length_size + DimTileSize - 1) / DimTileSize;
307+
auto config = platform::GetGpuLaunchConfig1D(ctx, total_stripe_count);
308+
SegmentSumIdsKernel<
309+
T, IndexT, IndexT(8)><<<config.block_per_grid.x,
310+
config.thread_per_block.x, 0, ctx.stream()>>>(
311+
segment_ids.data<IndexT>(), summed_ids->data<T>(), input_length_size,
312+
total_stripe_count);
313+
}
314+
275315
auto h = ArrangeHelper<IndexT>(input.numel(), segment_ids.dims()[0],
276316
output->dims()[0]);
277317
auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);
278318
if (pooltype == "MEAN") {
279-
SegmentMeanCustomKernel<
319+
SegmentMeanKernel<
280320
T, IndexT, IndexT(8)><<<config.block_per_grid.x,
281321
config.thread_per_block.x, 0, ctx.stream()>>>(
282322
segment_ids.data<IndexT>(), input.data<T>(), output->data<T>(),

0 commit comments

Comments
 (0)