110 changes: 105 additions & 5 deletions paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -15,8 +15,14 @@ limitations under the License. */
#pragma once

#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#ifdef __NVCC__
#include <cub/block/block_scan.cuh>
#endif
#ifdef __HIPCC__
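// On ROCm, hipCUB provides the CUB API; aliasing the namespace lets the
// cub::BlockScan usage below compile unchanged on both CUDA and HIP.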
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/kernels/sparse/conv_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
@@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs,
}
}

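// SWAR population count: adjacent bits are summed pairwise, then in fields of
// 2, 4, 8, and 16 bits, giving the number of set bits in a 32-bit word
// without a loop.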
inline __device__ uint32_t BitCount(const uint32_t data) {
uint32_t count = data;
count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
return count;
}

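// Stage 1: each thread popcounts one 32-bit flag word and the block
// accumulates the total, producing one set-bit count per block. These
// per-block counts are later exclusive-scanned into per-block write offsets.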
static __global__ void GetOutIndexsCounter(const int* flags,
const int n,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_count;
if (threadIdx.x == 0) {
block_count = 0;
}
__syncthreads();

if (tid < n) {
// count the set bits in flags[tid]
uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
// add to block_count
// TODO(zhangkaihuo): replace with block reduce_sum
atomicAdd(&block_count, static_cast<int>(count));
}
__syncthreads();
// write to out
if (threadIdx.x == 0) {
out[blockIdx.x] = block_count;
}
}

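// Stage 2: given the per-block write offsets, each block expands its flag
// words back into element indices (bit i of flags[tid] set means index
// tid * 32 + i is live), stages them in shared memory, and copies them to the
// output in order.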
template <int BS>
__global__ void GetOutIndexs(const int* flags,
const int n,
const int* offsets,
const int out_nnz,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_counts[BS];
__shared__ int block_outs[BS * 32];

int count = 0;

if (tid < n) {
// count the set bits in flags[tid]
int flag = flags[tid];
count = BitCount(static_cast<uint32_t>(flag));
}

// block-wide exclusive prefix sum of the per-thread bit counts gives each
// thread its starting write position within the block
typedef cub::BlockScan<int, BS> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
BlockScan(temp_storage).ExclusiveSum(count, count);
__syncthreads();

// expand this thread's flag word into shared memory
if (tid < n) {
// reload the flag word; "count" now holds this thread's exclusive
// write offset within the block
int flag = flags[tid];
int j = count;
// TODO(zhangkaihuo): opt the loop
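// bit i of the flag word corresponds to linear index tid * 32 + i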
for (int i = 0; i < 32; ++i) {
if ((1 & (flag >> i)) == 1) {
block_outs[j++] = (tid << 5) + i;
}
}
}

__syncthreads();
// copy this block's indices from shared memory to its slice of the output
int start = offsets[blockIdx.x];
int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
out[start + i] = block_outs[i];
}
}

template <typename IntT>
__global__ void GroupIndexs(const int* out_index_table,
const int n,
@@ -725,13 +813,25 @@ int ProductRuleBook(const Context& dev_ctx,
gpuMemcpyDeviceToHost,
dev_ctx.stream());
dev_ctx.Wait();

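// Compact the out-index bitmask in three steps: count the set bits handled
// by each block, exclusive-scan the per-block counts into write offsets, and
// expand the bits into ordered output indices.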
const int threads = 256;
const int blocks = (index_flags.numel() + threads - 1) / threads;
GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
index_flags_ptr, index_flags.numel(), out_index_table_ptr);
#ifdef PADDLE_WITH_HIP
- thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
+ thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
- thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
+ thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
- out_index_ptr,
- out_index_ptr + out_nnz);
+ out_index_table_ptr,
+ out_index_table_ptr + blocks,
+ out_index_table_ptr);
GetOutIndexs<threads>
<<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
index_flags.numel(),
out_index_table_ptr,
out_nnz,
out_index_ptr);

const int64_t sparse_dim = 4;
phi::DenseTensor out_indices =
4 changes: 2 additions & 2 deletions paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -125,7 +125,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,

#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
- if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
+ if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
if (std::is_same<T, float>::value) cutlass = false;
@@ -173,7 +173,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
if constexpr (std::is_same<T, float>::value &&
std::is_same<IntT, int32_t>::value) {
fp32_gather_gemm_scatter gather_gemm_scatter =
- getBestFp32Kernel(M, N, K);
+ getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
8 changes: 7 additions & 1 deletion paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
@@ -72,7 +72,13 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
}
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
const int N,
- const int K) {
+ const int K,
+ const int SM) {
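// Turing (SM 7.5) tensor cores have no fp32 MMA path, so fall back to the
// fp16-input, fp32-accumulate kernel regardless of problem shape.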
if (SM == 75) {
return launchKernel<
float,
cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
}
if (K == 4 && N == 16) {
return launchKernel<
float,
27 changes: 26 additions & 1 deletion paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
@@ -66,7 +66,8 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
const int N);
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
const int K,
- const int N);
+ const int N,
+ const int SM);
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
const int K,
const int N);
@@ -550,6 +551,30 @@ struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
false,
true>;
};

// sm75
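// Half-precision operands with fp32 accumulation: 64x64x32 threadblock tile,
// 32x32x32 warp tile, 16x8x8 tensor-op instruction shape, 2 pipeline stages.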
struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd>;
};

} // namespace sparse
} // namespace phi
#endif