From 856224278a01611b299109d9da01b40925e42da0 Mon Sep 17 00:00:00 2001 From: Guo Xiangmin Date: Wed, 14 May 2025 04:50:42 +0000 Subject: [PATCH 1/3] fix bincount kernel for big tensor --- paddle/phi/kernels/cpu/bincount_kernel.cc | 6 +- paddle/phi/kernels/gpu/bincount_kernel.cu | 105 ++++++++++++++++------ 2 files changed, 79 insertions(+), 32 deletions(-) diff --git a/paddle/phi/kernels/cpu/bincount_kernel.cc b/paddle/phi/kernels/cpu/bincount_kernel.cc index 1d091aaf4b9694..b10895ed36d043 100644 --- a/paddle/phi/kernels/cpu/bincount_kernel.cc +++ b/paddle/phi/kernels/cpu/bincount_kernel.cc @@ -24,7 +24,7 @@ template void BincountInner(const Context& dev_ctx, const DenseTensor& x, const paddle::optional& weights, - int minlength, + int64_t minlength, DenseTensor* out) { const DenseTensor* input = &x; DenseTensor* output = out; @@ -48,7 +48,7 @@ void BincountInner(const Context& dev_ctx, int64_t output_size = static_cast(*std::max_element( input_data, input_data + input_numel)) + 1L; - output_size = std::max(output_size, static_cast(minlength)); + output_size = std::max(output_size, minlength); phi::DDim out_dim{output_size}; output->Resize(out_dim); @@ -89,7 +89,7 @@ void BincountKernel(const Context& dev_ctx, const paddle::optional& weights, const Scalar& minlength, DenseTensor* out) { - int int_minlength = minlength.to(); + int64_t int_minlength = minlength.to(); PADDLE_ENFORCE_GE(int_minlength, 0, common::errors::InvalidArgument( diff --git a/paddle/phi/kernels/gpu/bincount_kernel.cu b/paddle/phi/kernels/gpu/bincount_kernel.cu index d8fb88b0a9b3e2..a24a98994f5150 100644 --- a/paddle/phi/kernels/gpu/bincount_kernel.cu +++ b/paddle/phi/kernels/gpu/bincount_kernel.cu @@ -24,22 +24,64 @@ namespace phi { using phi::PADDLE_CUDA_NUM_THREADS; -inline int GET_BLOCKS(const int N) { +inline int64_t GET_BLOCKS(const int64_t N) { return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; } +template +__global__ void KernelReduceMinMax(const T* input, + int64_t numel, + T* min_out, + T* max_out) { + __shared__ T smin[PADDLE_CUDA_NUM_THREADS]; + __shared__ T smax[PADDLE_CUDA_NUM_THREADS]; + int tid = threadIdx.x; + int64_t global_thread_id = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t stride = static_cast(gridDim.x) * blockDim.x; + + T local_min = std::numeric_limits::max(); + T local_max = std::numeric_limits::lowest(); + + for (int64_t i = global_thread_id; i < numel; i += stride) { + T val = input[i]; + local_min = min(local_min, val); + local_max = max(local_max, val); + } + + smin[tid] = local_min; + smax[tid] = local_max; + __syncthreads(); + + for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) { + if (tid < offset) { + smin[tid] = min(smin[tid], smin[tid + offset]); + smax[tid] = max(smax[tid], smax[tid + offset]); + } + __syncthreads(); + } + + if (tid == 0) { + phi::CudaAtomicMin(min_out, smin[0]); + phi::CudaAtomicMax(max_out, smax[0]); + } +} + template __global__ void KernelBincount(const InputT* input, - const int total_elements, + const int64_t total_elements, const bool has_weights, const T* weights, OutT* output) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < total_elements) { + int64_t global_tid = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + int64_t stride = static_cast(gridDim.x) * blockDim.x; + for (int64_t i = global_tid; i < total_elements; i += stride) { + InputT index = input[i]; if (!has_weights) { - phi::CudaAtomicAdd(&output[input[tid]], 1L); + phi::CudaAtomicAdd(&output[index], 1L); } else { - phi::CudaAtomicAdd(&output[input[tid]], static_cast(weights[tid])); + phi::CudaAtomicAdd(&output[index], static_cast(weights[i])); } } } @@ -48,13 +90,13 @@ template void BincountCUDAInner(const Context& dev_ctx, const DenseTensor& x, const paddle::optional& weights, - int minlength, + int64_t minlength, DenseTensor* out) { const DenseTensor* input = &x; DenseTensor* output = out; const InputT* input_data = input->data(); - const int input_numel = input->numel(); + int64_t input_numel = static_cast(input->numel()); if (input_data == nullptr) { phi::DDim out_dim{0}; @@ -62,25 +104,30 @@ void BincountCUDAInner(const Context& dev_ctx, dev_ctx.template Alloc(output); return; } - auto input_x = EigenVector::Flatten(*input); - DenseTensor input_min_t, input_max_t; - input_max_t.Resize({1}); - auto* input_max_data = dev_ctx.template Alloc(&input_max_t); - input_min_t.Resize({1}); - auto* input_min_data = dev_ctx.template Alloc(&input_min_t); - auto input_max_scala = EigenScalar::From(input_max_t); - auto input_min_scala = EigenScalar::From(input_min_t); + DenseTensor input_min_max_cpu; + input_min_max_cpu.Resize({2}); + input_min_max_cpu.mutable_data(phi::CPUPlace()); + input_min_max_cpu.data()[0] = std::numeric_limits::max(); + input_min_max_cpu.data()[1] = std::numeric_limits::lowest(); + + DenseTensor input_min_max_t; + input_min_max_t.Resize({2}); + auto* input_min_max_data = dev_ctx.template Alloc(&input_min_max_t); + + phi::Copy( + dev_ctx, input_min_max_cpu, dev_ctx.GetPlace(), true, &input_min_max_t); - auto* place = dev_ctx.eigen_device(); - input_max_scala.device(*place) = input_x.maximum(); - input_min_scala.device(*place) = input_x.minimum(); + int64_t max_grid_x = dev_ctx.GetCUDAMaxGridDimSize()[0]; + int64_t num_blocks = std::min(GET_BLOCKS(input_numel), max_grid_x); + KernelReduceMinMax + <<>>( + input_data, input_numel, input_min_max_data, input_min_max_data + 1); - DenseTensor input_min_cpu, input_max_cpu; - phi::Copy(dev_ctx, input_min_t, phi::CPUPlace(), true, &input_min_cpu); - phi::Copy(dev_ctx, input_max_t, phi::CPUPlace(), true, &input_max_cpu); + phi::Copy( + dev_ctx, input_min_max_t, phi::CPUPlace(), true, &input_min_max_cpu); - InputT input_min = input_min_cpu.data()[0]; + InputT input_min = input_min_max_cpu.data()[0]; PADDLE_ENFORCE_GE( input_min, @@ -89,9 +136,9 @@ void BincountCUDAInner(const Context& dev_ctx, "The elements in input tensor must be non-negative ints")); int64_t output_size = - static_cast(input_max_cpu.data()[0]) + 1L; + static_cast(input_min_max_cpu.data()[1]) + 1L; - output_size = std::max(output_size, static_cast(minlength)); + output_size = std::max(output_size, minlength); phi::DDim out_dim{output_size}; output->Resize(out_dim); @@ -106,7 +153,7 @@ void BincountCUDAInner(const Context& dev_ctx, dev_ctx, output, static_cast(0)); KernelBincount - <<>>( + <<>>( input_data, input_numel, has_weights, weights_data, output_data); } else { if (weights->dtype() == DataType::FLOAT32) { @@ -115,14 +162,14 @@ void BincountCUDAInner(const Context& dev_ctx, dev_ctx, output, static_cast(0)); KernelBincount - <<>>( + <<>>( input_data, input_numel, has_weights, weights_data, output_data); } else { double* output_data = dev_ctx.template Alloc(output); phi::funcs::SetConstant()( dev_ctx, output, static_cast(0)); KernelBincount - <<>>( + <<>>( input_data, input_numel, has_weights, weights_data, output_data); } } @@ -134,7 +181,7 @@ void BincountKernel(const Context& dev_ctx, const paddle::optional& weights, const Scalar& minlength, DenseTensor* out) { - int int_minlength = minlength.to(); + int64_t int_minlength = minlength.to(); PADDLE_ENFORCE_GE(int_minlength, 0, common::errors::InvalidArgument( From 8cb73abd4c9cc72baf46889bb75e0a2c15c69a05 Mon Sep 17 00:00:00 2001 From: Guo Xiangmin Date: Fri, 16 May 2025 02:38:26 +0000 Subject: [PATCH 2/3] use HostAlloc to alloc memory --- paddle/phi/kernels/gpu/bincount_kernel.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/bincount_kernel.cu b/paddle/phi/kernels/gpu/bincount_kernel.cu index a24a98994f5150..45733c3356093d 100644 --- a/paddle/phi/kernels/gpu/bincount_kernel.cu +++ b/paddle/phi/kernels/gpu/bincount_kernel.cu @@ -107,7 +107,8 @@ void BincountCUDAInner(const Context& dev_ctx, DenseTensor input_min_max_cpu; input_min_max_cpu.Resize({2}); - input_min_max_cpu.mutable_data(phi::CPUPlace()); + auto* input_min_max_cpu_data = + dev_ctx.template HostAlloc(&input_min_max_cpu); input_min_max_cpu.data()[0] = std::numeric_limits::max(); input_min_max_cpu.data()[1] = std::numeric_limits::lowest(); From d0bd4a47601ea766d3e9e2232eceda093cac8b87 Mon Sep 17 00:00:00 2001 From: Guo Xiangmin Date: Mon, 19 May 2025 11:24:45 +0000 Subject: [PATCH 3/3] add cpu test case --- test/legacy_test/test_bincount_op.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/legacy_test/test_bincount_op.py b/test/legacy_test/test_bincount_op.py index 8aa8aa0aa6d992..514dbce38c30a7 100644 --- a/test/legacy_test/test_bincount_op.py +++ b/test/legacy_test/test_bincount_op.py @@ -71,6 +71,18 @@ def test_dygraph(self): msg='bincount output is wrong, out =' + str(actual.numpy()), ) + def test_dygraph_cpu(self): + with base.dygraph.guard(): + paddle.device.set_device('cpu') + inputs_np = np.array([0, 1, 1, 3, 2, 1, 7]).astype(np.int64) + inputs = paddle.to_tensor(inputs_np) + actual = paddle.bincount(inputs) + expected = np.bincount(inputs) + self.assertTrue( + (actual.numpy() == expected).all(), + msg='bincount output is wrong, out =' + str(actual.numpy()), + ) + class TestBincountOpError(unittest.TestCase): """Test bincount op error."""