Skip to content

Commit bc0caca

Browse files
committed
fix bincount kernel for big tensor
1 parent 32d5f4b commit bc0caca

File tree

2 files changed

+79
-32
lines changed

2 files changed

+79
-32
lines changed

paddle/phi/kernels/cpu/bincount_kernel.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ template <typename Context, typename T, typename InputT>
2424
void BincountInner(const Context& dev_ctx,
2525
const DenseTensor& x,
2626
const paddle::optional<DenseTensor>& weights,
27-
int minlength,
27+
int64_t minlength,
2828
DenseTensor* out) {
2929
const DenseTensor* input = &x;
3030
DenseTensor* output = out;
@@ -48,7 +48,7 @@ void BincountInner(const Context& dev_ctx,
4848
int64_t output_size = static_cast<int64_t>(*std::max_element(
4949
input_data, input_data + input_numel)) +
5050
1L;
51-
output_size = std::max(output_size, static_cast<int64_t>(minlength));
51+
output_size = std::max(output_size, minlength);
5252

5353
phi::DDim out_dim{output_size};
5454
output->Resize(out_dim);
@@ -89,7 +89,7 @@ void BincountKernel(const Context& dev_ctx,
8989
const paddle::optional<DenseTensor>& weights,
9090
const Scalar& minlength,
9191
DenseTensor* out) {
92-
int int_minlength = minlength.to<int>();
92+
int64_t int_minlength = minlength.to<int64_t>();
9393
PADDLE_ENFORCE_GE(int_minlength,
9494
0,
9595
common::errors::InvalidArgument(

paddle/phi/kernels/gpu/bincount_kernel.cu

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,64 @@ namespace phi {
2424

2525
using phi::PADDLE_CUDA_NUM_THREADS;
2626

27-
inline int GET_BLOCKS(const int N) {
27+
inline int64_t GET_BLOCKS(const int64_t N) {
2828
return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
2929
}
3030

31+
template <typename T>
32+
__global__ void KernelReduceMinMax(const T* input,
33+
int64_t numel,
34+
T* min_out,
35+
T* max_out) {
36+
__shared__ T smin[PADDLE_CUDA_NUM_THREADS];
37+
__shared__ T smax[PADDLE_CUDA_NUM_THREADS];
38+
int tid = threadIdx.x;
39+
int64_t global_thread_id =
40+
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
41+
int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
42+
43+
T local_min = std::numeric_limits<T>::max();
44+
T local_max = std::numeric_limits<T>::lowest();
45+
46+
for (int64_t i = global_thread_id; i < numel; i += stride) {
47+
T val = input[i];
48+
local_min = min(local_min, val);
49+
local_max = max(local_max, val);
50+
}
51+
52+
smin[tid] = local_min;
53+
smax[tid] = local_max;
54+
__syncthreads();
55+
56+
for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
57+
if (tid < offset) {
58+
smin[tid] = min(smin[tid], smin[tid + offset]);
59+
smax[tid] = max(smax[tid], smax[tid + offset]);
60+
}
61+
__syncthreads();
62+
}
63+
64+
if (tid == 0) {
65+
phi::CudaAtomicMin(min_out, smin[0]);
66+
phi::CudaAtomicMax(max_out, smax[0]);
67+
}
68+
}
69+
3170
template <typename T, typename InputT, typename OutT>
3271
__global__ void KernelBincount(const InputT* input,
33-
const int total_elements,
72+
const int64_t total_elements,
3473
const bool has_weights,
3574
const T* weights,
3675
OutT* output) {
37-
int tid = blockIdx.x * blockDim.x + threadIdx.x;
38-
if (tid < total_elements) {
76+
int64_t global_tid =
77+
static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
78+
int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
79+
for (int64_t i = global_tid; i < total_elements; i += stride) {
80+
InputT index = input[i];
3981
if (!has_weights) {
40-
phi::CudaAtomicAdd(&output[input[tid]], 1L);
82+
phi::CudaAtomicAdd(&output[index], 1L);
4183
} else {
42-
phi::CudaAtomicAdd(&output[input[tid]], static_cast<OutT>(weights[tid]));
84+
phi::CudaAtomicAdd(&output[index], static_cast<OutT>(weights[i]));
4385
}
4486
}
4587
}
@@ -48,39 +90,44 @@ template <typename Context, typename T, typename InputT>
4890
void BincountCUDAInner(const Context& dev_ctx,
4991
const DenseTensor& x,
5092
const paddle::optional<DenseTensor>& weights,
51-
int minlength,
93+
int64_t minlength,
5294
DenseTensor* out) {
5395
const DenseTensor* input = &x;
5496
DenseTensor* output = out;
5597
const InputT* input_data = input->data<InputT>();
5698

57-
const int input_numel = input->numel();
99+
int64_t input_numel = static_cast<int64_t>(input->numel());
58100

59101
if (input_data == nullptr) {
60102
phi::DDim out_dim{0};
61103
output->Resize(out_dim);
62104
dev_ctx.template Alloc<T>(output);
63105
return;
64106
}
65-
auto input_x = EigenVector<InputT>::Flatten(*input);
66-
DenseTensor input_min_t, input_max_t;
67-
input_max_t.Resize({1});
68-
auto* input_max_data = dev_ctx.template Alloc<InputT>(&input_max_t);
69-
input_min_t.Resize({1});
70-
auto* input_min_data = dev_ctx.template Alloc<InputT>(&input_min_t);
71107

72-
auto input_max_scala = EigenScalar<InputT>::From(input_max_t);
73-
auto input_min_scala = EigenScalar<InputT>::From(input_min_t);
108+
DenseTensor input_min_max_cpu;
109+
input_min_max_cpu.Resize({2});
110+
input_min_max_cpu.mutable_data<InputT>(phi::CPUPlace());
111+
input_min_max_cpu.data<InputT>()[0] = std::numeric_limits<InputT>::max();
112+
input_min_max_cpu.data<InputT>()[1] = std::numeric_limits<InputT>::lowest();
113+
114+
DenseTensor input_min_max_t;
115+
input_min_max_t.Resize({2});
116+
auto* input_min_max_data = dev_ctx.template Alloc<InputT>(&input_min_max_t);
117+
118+
phi::Copy(
119+
dev_ctx, input_min_max_cpu, dev_ctx.GetPlace(), true, &input_min_max_t);
74120

75-
auto* place = dev_ctx.eigen_device();
76-
input_max_scala.device(*place) = input_x.maximum();
77-
input_min_scala.device(*place) = input_x.minimum();
121+
int64_t max_grid_x = dev_ctx.GetCUDAMaxGridDimSize()[0];
122+
int64_t num_blocks = std::min(GET_BLOCKS(input_numel), max_grid_x);
123+
KernelReduceMinMax<InputT>
124+
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>(
125+
input_data, input_numel, input_min_max_data, input_min_max_data + 1);
78126

79-
DenseTensor input_min_cpu, input_max_cpu;
80-
phi::Copy(dev_ctx, input_min_t, phi::CPUPlace(), true, &input_min_cpu);
81-
phi::Copy(dev_ctx, input_max_t, phi::CPUPlace(), true, &input_max_cpu);
127+
phi::Copy(
128+
dev_ctx, input_min_max_t, phi::CPUPlace(), true, &input_min_max_cpu);
82129

83-
InputT input_min = input_min_cpu.data<InputT>()[0];
130+
InputT input_min = input_min_max_cpu.data<InputT>()[0];
84131

85132
PADDLE_ENFORCE_GE(
86133
input_min,
@@ -89,9 +136,9 @@ void BincountCUDAInner(const Context& dev_ctx,
89136
"The elements in input tensor must be non-negative ints"));
90137

91138
int64_t output_size =
92-
static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;
139+
static_cast<int64_t>(input_min_max_cpu.data<InputT>()[1]) + 1L;
93140

94-
output_size = std::max(output_size, static_cast<int64_t>(minlength));
141+
output_size = std::max(output_size, minlength);
95142
phi::DDim out_dim{output_size};
96143
output->Resize(out_dim);
97144

@@ -106,7 +153,7 @@ void BincountCUDAInner(const Context& dev_ctx,
106153
dev_ctx, output, static_cast<int64_t>(0));
107154

108155
KernelBincount<T, InputT, int64_t>
109-
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
156+
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
110157
input_data, input_numel, has_weights, weights_data, output_data);
111158
} else {
112159
if (weights->dtype() == DataType::FLOAT32) {
@@ -115,14 +162,14 @@ void BincountCUDAInner(const Context& dev_ctx,
115162
dev_ctx, output, static_cast<float>(0));
116163

117164
KernelBincount<T, InputT, float>
118-
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
165+
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
119166
input_data, input_numel, has_weights, weights_data, output_data);
120167
} else {
121168
double* output_data = dev_ctx.template Alloc<double>(output);
122169
phi::funcs::SetConstant<Context, double>()(
123170
dev_ctx, output, static_cast<double>(0));
124171
KernelBincount<T, InputT, double>
125-
<<<GET_BLOCKS(input_numel), PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
172+
<<<num_blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
126173
input_data, input_numel, has_weights, weights_data, output_data);
127174
}
128175
}
@@ -134,7 +181,7 @@ void BincountKernel(const Context& dev_ctx,
134181
const paddle::optional<DenseTensor>& weights,
135182
const Scalar& minlength,
136183
DenseTensor* out) {
137-
int int_minlength = minlength.to<int>();
184+
int64_t int_minlength = minlength.to<int64_t>();
138185
PADDLE_ENFORCE_GE(int_minlength,
139186
0,
140187
common::errors::InvalidArgument(

0 commit comments

Comments (0)