From b8e1db01b5a377857306c2f60e7490c436561f6a Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Mon, 12 Dec 2022 03:29:24 +0000
Subject: [PATCH 1/5] remove sort

---
 paddle/phi/kernels/sparse/gpu/conv.cu.h | 142 ++++++++++++++++++++----
 1 file changed, 122 insertions(+), 20 deletions(-)

diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index 8618171b8f905a..180332e0c1d92f 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -15,8 +15,8 @@ limitations under the License. */
 #pragma once
 
 #include <thrust/execution_policy.h>
-#include <thrust/sort.h>
 #include <thrust/unique.h>
+#include <cub/cub.cuh>
 
 #include "paddle/phi/kernels/sparse/conv_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
@@ -199,23 +199,107 @@ __global__ void UniqueKernel(const IntT* in_indexs,
   }
 }
 
+inline __device__ uint32_t BitCount(const uint32_t data) {
+    uint32_t n = data;
+    n = (n &0x55555555) + ((n >>1) &0x55555555);
+    n = (n &0x33333333) + ((n >>2) &0x33333333);
+    n = (n &0x0f0f0f0f) + ((n >>4) &0x0f0f0f0f);
+    n = (n &0x00ff00ff) + ((n >>8) &0x00ff00ff);
+    n = (n &0x0000ffff) + ((n >>16) &0x0000ffff);
+    return n;
+}
+
+static __global__ void GetOutIndexsCounter(
+    const int* flags, const int n, int* out) {
+    int tid = threadIdx.x + blockDim.x * blockIdx.x;
+    __shared__ int block_count;
+    if(threadIdx.x == 0) {
+        block_count = 0;
+    }
+    __syncthreads();
+
+    if(tid < n) {
+        // get the count of 1 in flags[tid]
+        uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
+        // add to block_count
+        atomicAdd(&block_count, static_cast<int>(count));
+    }
+    __syncthreads();
+    // write to out
+    if(threadIdx.x == 0) {
+        out[blockIdx.x] = block_count;
+    }
+}
+
+template<int BS>
+__global__ void GetOutIndexs(const int* flags,
+    const int n,
+    const int* offsets,
+    const int out_nnz,
+    int* out) {
+    int tid = threadIdx.x + blockDim.x * blockIdx.x;
+    __shared__ int block_counts[BS];
+    __shared__ int block_outs[BS * 32];
+
+    // block_counts[threadIdx.x] = 0;
+    int count = 0;
+
+    if(tid < n) {
+        // get the count of 1 in flags[tid]
+        int flag = flags[tid];
+        count = BitCount(static_cast<uint32_t>(flag));
+        // block_counts[threadIdx.x] = count;
+    }
+
+    // call block prefix_sum
+    // using namespace cub;
+    typedef cub::BlockScan<int, BS> BlockScan;
+    __shared__ typename BlockScan::TempStorage temp_storage;
+    BlockScan(temp_storage).ExclusiveSum(count, count);
+    __syncthreads();
+
+    // block_counts[threadIdx.x] = count;
+    // write index to out
+
+    if(tid < n) {
+        // get the count of 1 in flags[tid]
+        int flag = flags[tid];
+        // int j = block_counts[threadIdx.x];
+        int j = count;
+        // TODO(zhangkaihuo): opt the loop
+        for(int i = 0; i < 32; ++i) {
+            if((1 & (flag >> i)) == 1) {
+                block_outs[j++] = (tid << 5) + i;
+            }
+        }
+    }
+
+    __syncthreads();
+    // write to block_outs
+    int start = offsets[blockIdx.x];
+    int end = blockIdx.x == gridDim.x-1 ? out_nnz : offsets[blockIdx.x + 1];
+    for(int i = threadIdx.x; i < end-start; i+=blockDim.x) {
+        out[start + i] = block_outs[i];
+    }
+}
+
 template <typename IntT>
 __global__ void GroupIndexs(const int* out_index_table,
-                            const int n,
-                            const int kernel_size,
-                            IntT* out_indexs,
-                            int* out_index_counts,
-                            int* out_index_groups) {
-  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
-    IntT index = out_indexs[i];
-    int real_index = out_index_table[index];
-    out_indexs[i] = real_index;
-
-    // kernel_size at most
-    int j = atomicAdd(out_index_counts + real_index, 1);
-    // nnz * kernel_size
-    out_index_groups[real_index * kernel_size + j] = i;
-  }
+    const int n,
+    const int kernel_size,
+    IntT* out_indexs,
+    int* out_index_counts,
+    int* out_index_groups) {
+    CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
+        IntT index = out_indexs[i];
+        int real_index = out_index_table[index];
+        out_indexs[i] = real_index;
+
+        // kernel_size at most
+        int j = atomicAdd(out_index_counts + real_index, 1);
+        // nnz * kernel_size
+        out_index_groups[real_index * kernel_size + j] = i;
+    }
 }
 
 /**
@@ -725,13 +809,31 @@ int ProductRuleBook(const Context& dev_ctx,
                     gpuMemcpyDeviceToHost,
                     dev_ctx.stream());
   dev_ctx.Wait();
+// #ifdef PADDLE_WITH_HIP
+//   thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
+// #else
+//   thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
+// #endif
+//                out_index_ptr,
+//                out_index_ptr + out_nnz);
+
+  if(true) {
+    const int threads = 256;
+    const int blocks = (index_flags.numel() + threads - 1) / threads;
+    GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
+      index_flags_ptr, index_flags.numel(), out_index_table_ptr);
 #ifdef PADDLE_WITH_HIP
-  thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
+    thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
 #else
-  thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
+    thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
 #endif
-               out_index_ptr,
-               out_index_ptr + out_nnz);
+      out_index_table_ptr,
+      out_index_table_ptr + blocks,
+      out_index_table_ptr);
+    GetOutIndexs<threads><<<blocks, threads, 0, dev_ctx.stream()>>>(
+      index_flags_ptr, index_flags.numel(), out_index_table_ptr,
+      out_nnz, out_index_ptr);
+  }
 
   const int64_t sparse_dim = 4;
   phi::DenseTensor out_indices =
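
The patch above replaces a device-wide thrust::sort over the output indices
with a popcount-plus-scan compaction: each 32-bit word of index_flags marks up
to 32 occupied output positions, GetOutIndexsCounter sums the set bits per
block, thrust::exclusive_scan turns the per-block sums into block offsets, and
GetOutIndexs expands each set bit back into its index. Because words and bits
are visited in ascending order, the compacted indices come out already sorted,
which is what makes the sort removable. A minimal host-side sketch of the same
compaction follows; the function name is illustrative only, not a Paddle API:

    #include <cstdint>
    #include <vector>

    // Reference semantics of GetOutIndexsCounter + GetOutIndexs: expand a
    // bitmask into the sorted list of its set-bit positions.
    std::vector<int> CompactBitmask(const std::vector<uint32_t>& flags) {
      std::vector<int> out;
      for (size_t word = 0; word < flags.size(); ++word) {
        for (int bit = 0; bit < 32; ++bit) {
          if (flags[word] & (1u << bit)) {
            // (word << 5) + bit recovers the original index; scanning words
            // and bits in ascending order is what keeps the output sorted.
            out.push_back(static_cast<int>((word << 5) + bit));
          }
        }
      }
      return out;
    }

The net effect is trading an O(n log n) device sort for two linear passes over
the bitmask plus one small scan over per-block counts.
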
From 09f041d508682ead711dddcf20705cfd5b41e7bc Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Mon, 12 Dec 2022 08:15:06 +0000
Subject: [PATCH 2/5] add todo

---
 paddle/phi/kernels/sparse/gpu/conv.cu.h | 36 +++++++++----------------
 1 file changed, 12 insertions(+), 24 deletions(-)

diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index 180332e0c1d92f..66b7fd7129eb96 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -222,6 +222,7 @@ static __global__ void GetOutIndexsCounter(
         // get the count of 1 in flags[tid]
         uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
         // add to block_count
+        // TODO(zhangkaihuo): replace with block reduce_sum
         atomicAdd(&block_count, static_cast<int>(count));
     }
     __syncthreads();
@@ -241,14 +242,12 @@ __global__ void GetOutIndexs(const int* flags,
     __shared__ int block_counts[BS];
     __shared__ int block_outs[BS * 32];
 
-    // block_counts[threadIdx.x] = 0;
     int count = 0;
 
     if(tid < n) {
         // get the count of 1 in flags[tid]
        int flag = flags[tid];
         count = BitCount(static_cast<uint32_t>(flag));
-        // block_counts[threadIdx.x] = count;
     }
 
     // call block prefix_sum
@@ -258,9 +257,7 @@ __global__ void GetOutIndexs(const int* flags,
     BlockScan(temp_storage).ExclusiveSum(count, count);
     __syncthreads();
 
-    // block_counts[threadIdx.x] = count;
     // write index to out
-
     if(tid < n) {
         // get the count of 1 in flags[tid]
         int flag = flags[tid];
@@ -809,31 +806,22 @@ int ProductRuleBook(const Context& dev_ctx,
                     gpuMemcpyDeviceToHost,
                     dev_ctx.stream());
   dev_ctx.Wait();
-// #ifdef PADDLE_WITH_HIP
-//   thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
-// #else
-//   thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
-// #endif
-//                out_index_ptr,
-//                out_index_ptr + out_nnz);
-
-  if(true) {
-    const int threads = 256;
-    const int blocks = (index_flags.numel() + threads - 1) / threads;
-    GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
-      index_flags_ptr, index_flags.numel(), out_index_table_ptr);
+
+  const int threads = 256;
+  const int blocks = (index_flags.numel() + threads - 1) / threads;
+  GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
+    index_flags_ptr, index_flags.numel(), out_index_table_ptr);
 #ifdef PADDLE_WITH_HIP
-    thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
+  thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
 #else
-    thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
+  thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
 #endif
-      out_index_table_ptr,
-      out_index_table_ptr + blocks,
-      out_index_table_ptr);
-    GetOutIndexs<threads><<<blocks, threads, 0, dev_ctx.stream()>>>(
+    out_index_table_ptr,
+    out_index_table_ptr + blocks,
+    out_index_table_ptr);
+  GetOutIndexs<threads><<<blocks, threads, 0, dev_ctx.stream()>>>(
       index_flags_ptr, index_flags.numel(), out_index_table_ptr,
       out_nnz, out_index_ptr);
-  }
 
   const int64_t sparse_dim = 4;
   phi::DenseTensor out_indices =
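
PATCH 2 keeps each thread's popcount in a register and relies on
cub::BlockScan for the intra-block write offsets. A self-contained sketch of
that primitive, assuming CUDA with the CUB headers available (the kernel name
and buffers are illustrative):

    #include <cub/cub.cuh>

    // Launch with BS threads per block. ExclusiveSum hands every thread the
    // sum of the values held by lower-ranked threads in its block, i.e. the
    // same write offset GetOutIndexs uses to place indices in block_outs.
    template <int BS>
    __global__ void ExclusiveSumDemo(const int* in, int* out, int n) {
      typedef cub::BlockScan<int, BS> BlockScan;
      __shared__ typename BlockScan::TempStorage temp_storage;
      int tid = blockIdx.x * BS + threadIdx.x;
      int v = tid < n ? in[tid] : 0;
      // Collective call: every thread in the block must participate.
      BlockScan(temp_storage).ExclusiveSum(v, v);
      if (tid < n) out[tid] = v;
    }

Note that the scan call sits outside any divergent branch, the same discipline
GetOutIndexs follows before its __syncthreads().
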
From 7285c9afbbf87366c86fade4c39b0efcf8d91d87 Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Mon, 12 Dec 2022 09:28:02 +0000
Subject: [PATCH 3/5] for sm75

---
 paddle/phi/kernels/sparse/gpu/conv.cu.h       | 190 +++++++++---------
 paddle/phi/kernels/sparse/gpu/conv_kernel.cu  |   4 +-
 .../kernels/sparse/gpu/gather_gemm_scatter.cu |   8 +-
 .../kernels/sparse/gpu/gather_gemm_scatter.h  |  27 ++-
 4 files changed, 132 insertions(+), 97 deletions(-)

diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index 66b7fd7129eb96..66d1a6f9d4e89e 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -200,103 +200,104 @@ __global__ void UniqueKernel(const IntT* in_indexs,
 }
 
 inline __device__ uint32_t BitCount(const uint32_t data) {
-    uint32_t n = data;
-    n = (n &0x55555555) + ((n >>1) &0x55555555);
-    n = (n &0x33333333) + ((n >>2) &0x33333333);
-    n = (n &0x0f0f0f0f) + ((n >>4) &0x0f0f0f0f);
-    n = (n &0x00ff00ff) + ((n >>8) &0x00ff00ff);
-    n = (n &0x0000ffff) + ((n >>16) &0x0000ffff);
-    return n;
+  uint32_t n = data;
+  n = (n & 0x55555555) + ((n >> 1) & 0x55555555);
+  n = (n & 0x33333333) + ((n >> 2) & 0x33333333);
+  n = (n & 0x0f0f0f0f) + ((n >> 4) & 0x0f0f0f0f);
+  n = (n & 0x00ff00ff) + ((n >> 8) & 0x00ff00ff);
+  n = (n & 0x0000ffff) + ((n >> 16) & 0x0000ffff);
+  return n;
 }
 
-static __global__ void GetOutIndexsCounter(
-    const int* flags, const int n, int* out) {
-    int tid = threadIdx.x + blockDim.x * blockIdx.x;
-    __shared__ int block_count;
-    if(threadIdx.x == 0) {
-        block_count = 0;
-    }
-    __syncthreads();
-
-    if(tid < n) {
-        // get the count of 1 in flags[tid]
-        uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
-        // add to block_count
-        // TODO(zhangkaihuo): replace with block reduce_sum
-        atomicAdd(&block_count, static_cast<int>(count));
-    }
-    __syncthreads();
-    // write to out
-    if(threadIdx.x == 0) {
-        out[blockIdx.x] = block_count;
-    }
+static __global__ void GetOutIndexsCounter(const int* flags,
+                                           const int n,
+                                           int* out) {
+  int tid = threadIdx.x + blockDim.x * blockIdx.x;
+  __shared__ int block_count;
+  if (threadIdx.x == 0) {
+    block_count = 0;
+  }
+  __syncthreads();
+
+  if (tid < n) {
+    // get the count of 1 in flags[tid]
+    uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
+    // add to block_count
+    // TODO(zhangkaihuo): replace with block reduce_sum
+    atomicAdd(&block_count, static_cast<int>(count));
+  }
+  __syncthreads();
+  // write to out
+  if (threadIdx.x == 0) {
+    out[blockIdx.x] = block_count;
+  }
 }
 
-template<int BS>
+template <int BS>
 __global__ void GetOutIndexs(const int* flags,
-    const int n,
-    const int* offsets,
-    const int out_nnz,
-    int* out) {
-    int tid = threadIdx.x + blockDim.x * blockIdx.x;
-    __shared__ int block_counts[BS];
-    __shared__ int block_outs[BS * 32];
-
-    int count = 0;
-
-    if(tid < n) {
-        // get the count of 1 in flags[tid]
-        int flag = flags[tid];
-        count = BitCount(static_cast<uint32_t>(flag));
-    }
-
-    // call block prefix_sum
-    // using namespace cub;
-    typedef cub::BlockScan<int, BS> BlockScan;
-    __shared__ typename BlockScan::TempStorage temp_storage;
-    BlockScan(temp_storage).ExclusiveSum(count, count);
-    __syncthreads();
-
-    // write index to out
-    if(tid < n) {
-        // get the count of 1 in flags[tid]
-        int flag = flags[tid];
-        // int j = block_counts[threadIdx.x];
-        int j = count;
-        // TODO(zhangkaihuo): opt the loop
-        for(int i = 0; i < 32; ++i) {
-            if((1 & (flag >> i)) == 1) {
-                block_outs[j++] = (tid << 5) + i;
-            }
-        }
-    }
-
-    __syncthreads();
-    // write to block_outs
-    int start = offsets[blockIdx.x];
-    int end = blockIdx.x == gridDim.x-1 ? out_nnz : offsets[blockIdx.x + 1];
-    for(int i = threadIdx.x; i < end-start; i+=blockDim.x) {
-        out[start + i] = block_outs[i];
-    }
+                             const int n,
+                             const int* offsets,
+                             const int out_nnz,
+                             int* out) {
+  int tid = threadIdx.x + blockDim.x * blockIdx.x;
+  __shared__ int block_counts[BS];
+  __shared__ int block_outs[BS * 32];
+
+  int count = 0;
+
+  if (tid < n) {
+    // get the count of 1 in flags[tid]
+    int flag = flags[tid];
+    count = BitCount(static_cast<uint32_t>(flag));
+  }
+
+  // call block prefix_sum
+  // using namespace cub;
+  typedef cub::BlockScan<int, BS> BlockScan;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+  BlockScan(temp_storage).ExclusiveSum(count, count);
+  __syncthreads();
+
+  // write index to out
+  if (tid < n) {
+    // get the count of 1 in flags[tid]
+    int flag = flags[tid];
+    // int j = block_counts[threadIdx.x];
+    int j = count;
+    // TODO(zhangkaihuo): opt the loop
+    for (int i = 0; i < 32; ++i) {
+      if ((1 & (flag >> i)) == 1) {
+        block_outs[j++] = (tid << 5) + i;
+      }
+    }
+  }
+
+  __syncthreads();
+  // write to block_outs
+  int start = offsets[blockIdx.x];
+  int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
+  for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
+    out[start + i] = block_outs[i];
+  }
 }
 
 template <typename IntT>
 __global__ void GroupIndexs(const int* out_index_table,
-    const int n,
-    const int kernel_size,
-    IntT* out_indexs,
-    int* out_index_counts,
-    int* out_index_groups) {
-    CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
-        IntT index = out_indexs[i];
-        int real_index = out_index_table[index];
-        out_indexs[i] = real_index;
-
-        // kernel_size at most
-        int j = atomicAdd(out_index_counts + real_index, 1);
-        // nnz * kernel_size
-        out_index_groups[real_index * kernel_size + j] = i;
-    }
+                            const int n,
+                            const int kernel_size,
+                            IntT* out_indexs,
+                            int* out_index_counts,
+                            int* out_index_groups) {
+  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
+    IntT index = out_indexs[i];
+    int real_index = out_index_table[index];
+    out_indexs[i] = real_index;
+
+    // kernel_size at most
+    int j = atomicAdd(out_index_counts + real_index, 1);
+    // nnz * kernel_size
+    out_index_groups[real_index * kernel_size + j] = i;
+  }
 }
 
 /**
@@ -810,18 +811,21 @@ int ProductRuleBook(const Context& dev_ctx,
   const int threads = 256;
   const int blocks = (index_flags.numel() + threads - 1) / threads;
   GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
-    index_flags_ptr, index_flags.numel(), out_index_table_ptr);
+      index_flags_ptr, index_flags.numel(), out_index_table_ptr);
 #ifdef PADDLE_WITH_HIP
   thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
 #else
-  thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
+  thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
 #endif
-    out_index_table_ptr,
-    out_index_table_ptr + blocks,
-    out_index_table_ptr);
-  GetOutIndexs<threads><<<blocks, threads, 0, dev_ctx.stream()>>>(
-      index_flags_ptr, index_flags.numel(), out_index_table_ptr,
-      out_nnz, out_index_ptr);
+                         out_index_table_ptr,
+                         out_index_table_ptr + blocks,
+                         out_index_table_ptr);
+  GetOutIndexs<threads>
+      <<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
+                                                 index_flags.numel(),
+                                                 out_index_table_ptr,
+                                                 out_nnz,
+                                                 out_index_ptr);
 
   const int64_t sparse_dim = 4;
   phi::DenseTensor out_indices =
diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
index 87037581e52f79..e6f3ca33649187 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -125,7 +125,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
 
 #ifdef PADDLE_WITH_CUTLASS
   bool cutlass = true;
-  if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
+  if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
   if (in_channels % 4 != 0 || out_channels % 4 != 0) {
     if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
     if (std::is_same<T, float>::value) cutlass = false;
@@ -173,7 +173,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
       if constexpr (std::is_same<T, float>::value &&
                     std::is_same<IntT, int32_t>::value) {
         fp32_gather_gemm_scatter gather_gemm_scatter =
-            getBestFp32Kernel(M, N, K);
+            getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
         gather_gemm_scatter(dev_ctx,
                             x.non_zero_elements().data<T>(),
                             tmp_kernel_ptr,
diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
index 48727c8f8513df..cfbaa7f1d63068 100644
--- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
+++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
@@ -72,7 +72,13 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
 }
 fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                            const int N,
-                                           const int K) {
+                                           const int K,
+                                           const int SM) {
+  if (SM == 75) {
+    return launchKernel<
+        float,
+        cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
+  }
   if (K == 4 && N == 16) {
     return launchKernel<
         float,
diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
index 462cd710340678..b596ff545383fe 100644
--- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
@@ -66,7 +66,8 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                            const int N);
 fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                            const int K,
-                                           const int N);
+                                           const int N,
+                                           const int SM);
 fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                            const int K,
                                            const int N);
@@ -550,6 +551,30 @@ struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
       false,
       true>;
 };
+
+// sm75
+struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+      cutlass::half_t,
+      cutlass::layout::RowMajor,
+      cutlass::half_t,
+      cutlass::layout::RowMajor,
+      float,
+      cutlass::layout::RowMajor,
+      float,
+      cutlass::arch::OpClassTensorOp,
+      cutlass::arch::Sm75,
+      cutlass::gemm::GemmShape<64, 64, 32>,
+      cutlass::gemm::GemmShape<32, 32, 32>,
+      cutlass::gemm::GemmShape<16, 8, 8>,
+      cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+      2,
+      8,
+      8,
+      cutlass::arch::OpMultiplyAdd>;
+};
 
 }  // namespace sparse
 }  // namespace phi
 #endif
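
For readers unfamiliar with CUTLASS template naming, an annotated reading of
the SM75 configuration added above may help; the annotations are editorial,
and the parameter roles follow the declaration order of
cutlass::gemm::device::GemmUniversal:

    #include "cutlass/gemm/device/gemm_universal.h"

    using GemmSm75 = cutlass::gemm::device::GemmUniversal<
        cutlass::half_t, cutlass::layout::RowMajor,  // A: fp16, row-major
        cutlass::half_t, cutlass::layout::RowMajor,  // B: fp16, row-major
        float, cutlass::layout::RowMajor,            // C/D: fp32, row-major
        float,                                       // accumulator type
        cutlass::arch::OpClassTensorOp,              // run on tensor cores
        cutlass::arch::Sm75,                         // Turing target
        cutlass::gemm::GemmShape<64, 64, 32>,        // threadblock tile M,N,K
        cutlass::gemm::GemmShape<32, 32, 32>,        // warp tile
        cutlass::gemm::GemmShape<16, 8, 8>,          // Turing HMMA shape
        cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
        2,                                           // pipeline stages
        8, 8,                                        // A/B alignment
        cutlass::arch::OpMultiplyAdd>;

The two-stage pipeline and the 16x8x8 instruction shape are what distinguish
Turing from the Sm80-targeted kernels earlier in the file; together with the
new SM argument to getBestFp32Kernel, they let sm75 devices take the CUTLASS
path that the conv_kernel.cu change opens up (GetComputeCapability() < 75
instead of < 80).
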
From 79f097e97d650aad907b7db24ef8a564c87494dd Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Mon, 12 Dec 2022 11:21:58 +0000
Subject: [PATCH 4/5] rename variable

---
 paddle/phi/kernels/sparse/gpu/conv.cu.h | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index 66d1a6f9d4e89e..a34a2ebb8cb248 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -200,13 +200,13 @@ __global__ void UniqueKernel(const IntT* in_indexs,
 }
 
 inline __device__ uint32_t BitCount(const uint32_t data) {
-  uint32_t n = data;
-  n = (n & 0x55555555) + ((n >> 1) & 0x55555555);
-  n = (n & 0x33333333) + ((n >> 2) & 0x33333333);
-  n = (n & 0x0f0f0f0f) + ((n >> 4) & 0x0f0f0f0f);
-  n = (n & 0x00ff00ff) + ((n >> 8) & 0x00ff00ff);
-  n = (n & 0x0000ffff) + ((n >> 16) & 0x0000ffff);
-  return n;
+  uint32_t count = data;
+  count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
+  count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
+  count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
+  count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
+  count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
+  return count;
 }
 
 static __global__ void GetOutIndexsCounter(const int* flags,

From 330d63c82b698867677201ac1961a4f76fdf8eb1 Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Tue, 13 Dec 2022 06:07:59 +0000
Subject: [PATCH 5/5] for hip

---
 paddle/phi/kernels/sparse/gpu/conv.cu.h | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index a34a2ebb8cb248..61457e506b22d1 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -16,7 +16,13 @@ limitations under the License. */
 
 #include <thrust/execution_policy.h>
 #include <thrust/unique.h>
+#ifdef __NVCC__
 #include <cub/cub.cuh>
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 
 #include "paddle/phi/kernels/sparse/conv_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
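
A closing note on BitCount: the mask-and-shift ladder is the classic parallel
population count. CUDA and HIP both expose the same operation as the __popc
intrinsic, so a hypothetical one-line alternative (not part of this series)
would be:

    // Equivalent of BitCount using the hardware population-count instruction.
    inline __device__ uint32_t BitCountIntrinsic(const uint32_t data) {
      return static_cast<uint32_t>(__popc(data));
    }

The TODO recorded in PATCH 2, replacing the atomicAdd accumulation in
GetOutIndexsCounter with a block reduce_sum, remains the follow-up the author
flags for this path.
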