diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc
index 36189cc7e4c90d..73704b04cf90b2 100644
--- a/paddle/fluid/platform/device/gpu/gpu_info.cc
+++ b/paddle/fluid/platform/device/gpu/gpu_info.cc
@@ -217,6 +217,7 @@ class RecordedGpuMallocHelper {
     CUDADeviceGuard guard(dev_id_);
     gpuError_t result;
 #ifdef PADDLE_WITH_HIP
+    phi::backends::gpu::CUDAGraphCaptureModeGuard capture_mode_guard;
     if (UNLIKELY(malloc_managed_memory)) {
       result = hipMallocManaged(ptr, size);
     } else {
diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h
index ad30da4ddcd6f0..03da0544500920 100644
--- a/paddle/phi/core/visit_type.h
+++ b/paddle/phi/core/visit_type.h
@@ -355,7 +355,7 @@ namespace phi {
         "`");                                                       \
     }                                                               \
   }()
-#if defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_HIP)
+#if defined(PADDLE_WITH_XPU)
 #define PD_VISIT_ALL_TYPES(TYPE, NAME, ...)                         \
   [&] {                                                             \
     const auto& __dtype__ = TYPE;                                   \
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index abe752e27fd391..891888bf8b5850 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -209,11 +209,9 @@ if(WITH_ROCM)
       "gpu/lu_kernel.cu"
       "gpu/matrix_rank_kernel.cu"
       "gpu/matrix_rank_tol_kernel.cu"
-      "gpu/multiclass_nms3_kernel.cu"
       "gpu/put_along_axis_grad_kernel.cu"
       "gpu/put_along_axis_kernel.cu"
       "gpu/qr_kernel.cu"
-      "gpu/rms_norm_grad_kernel.cu"
       "gpu/svd_kernel.cu"
       "gpudnn/mha_cudnn_frontend.cu"
       "fusion/gpu/block_multi_head_attention_kernel.cu"
diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h
index 463272a37c00d3..855b6fe6c8e15c 100644
--- a/paddle/phi/kernels/funcs/dropout_impl.cu.h
+++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h
@@ -349,19 +349,6 @@ void DropoutFwGPUKernelDriver(
   } else {
     bool copy_in_kernel = GetSeedDataAndIncrement(
         dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
-#ifdef PADDLE_WITH_HIP
-    VectorizedRandomGenerator
-        <<>>(0,
-                size,
-                seed_data,
-                dropout_prob,
-                x_data,
-                mask_data,
-                y_data,
-                upscale_in_train,
-                increment,
-                main_offset);
-#else
     const phi::GPUContext* dev_ctx_p = &dev_ctx;
     auto gen_cuda = dev_ctx.GetGenerator();
     auto state_index = gen_cuda->GetStateIndex();
@@ -370,10 +357,11 @@ void DropoutFwGPUKernelDriver(
     parameterSetter = [offset, dev_ctx_p, state_index, is_fix_seed](
                           phi::backends::gpu::gpuKernelParams& params) {
       if (!is_fix_seed) {
-        // we assume seed is null pointer
-        // seed copy to cpu is meaningless here
+      // we assume seed is null pointer
+      // seed copy to cpu is meaningless here
+#ifndef PADDLE_WITH_HIP
         assert(seed_tensor_ptr == nullptr);
-
+#endif
         auto gen_cuda = dev_ctx_p->GetGenerator();
         // ensure the generator use correct state index
         gen_cuda->SetStateIndex(state_index);
@@ -393,9 +381,14 @@ void DropoutFwGPUKernelDriver(
     cudaKernelCallback = [=](unsigned int id) {
       void* functionPtr =
           reinterpret_cast<void*>(&(VectorizedRandomGenerator));
+#ifdef PADDLE_WITH_HIP
+      hipFunction_t cudaFunc =
+          reinterpret_cast<hipFunction_t>(functionPtr);
+#else
       cudaFunction_t cudaFunc;
       PADDLE_ENFORCE_GPU_SUCCESS(
           cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+#endif
       VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
                << " functionPtr = " << functionPtr;
@@ -417,7 +410,6 @@ void DropoutFwGPUKernelDriver(
     VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
              << ", increment = " << increment;
-#endif
     }
   } else {
     if (upscale_in_train) {
diff --git a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
index 6a82875819161b..3eee52efcbebe6 100644
--- a/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
+++ b/paddle/phi/kernels/funcs/layer_norm_impl.cu.h
@@ -166,14 +166,14 @@ __inline__ __device__ double rsqrt_(const double val) {
   return ::rsqrt(val);
 }
 
-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) || defined(PADDLE_WITH_HIP)
 template <>
 __inline__ __device__ half rsqrt_(const half val) {
   return hrsqrt(val);
 }
 #endif
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template 1) {
       if (lane == 0) {
@@ -290,7 +294,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
 #pragma unroll
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
+#ifdef PADDLE_WITH_HIP
+      var_local += __shfl_xor(var_local, it);
+#else
       var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
+#endif
     }
 
     if (WARPS_N > 1) {
@@ -546,7 +554,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
   }
 }
 
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template 0; it /= 2) {
+#ifdef PADDLE_WITH_HIP
+        sum_loss1 += __shfl_down(sum_loss1, it);
+        sum_loss2 += __shfl_down(sum_loss2, it);
+#else
         sum_loss1 += __shfl_down_sync(uint32_t(-1), sum_loss1, it);
         sum_loss2 += __shfl_down_sync(uint32_t(-1), sum_loss2, it);
+#endif
       }
       if (lane == 0) {
diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu
index 60a82cfe7c1980..48819c12a8dc0e 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu
@@ -11,7 +11,12 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#ifndef PADDLE_WITH_HIP
+#ifdef PADDLE_WITH_HIP
+#include
+#include
+#include
+namespace cub = hipcub;
+#else
 #include
 #include
 #endif
@@ -21,9 +26,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
-#ifndef PADDLE_WITH_HIP
 #include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
-#endif
 
 namespace phi {
 namespace fusion {
@@ -51,7 +54,6 @@ void FusedBiasDropoutResidualLnGradKernel(
     DenseTensor* bias_grad,
     DenseTensor* ln_scale_grad,
     DenseTensor* ln_bias_grad) {
-#ifndef PADDLE_WITH_HIP
   using U = LayerNormParamType<T>;
   auto* d_y_data = y_grad.data<T>();
   auto* ln_scale_data =
@@ -114,15 +116,19 @@
       d_x_data,
       d_bias_data,
       d_residual_data);
-#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FusedBiasDropoutResidualLnGradKernel not surpport for rocm"));
-#endif
 }
 
 }  // namespace fusion
 }  // namespace phi
 
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedBiasDropoutResidualLnGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+#else
 PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -130,3 +136,4 @@ PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
                    float,
                    double,
                    phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu
index 37450d3a4e178b..ca0bcbe7f2466a 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu
@@ -17,9 +17,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
-#ifndef PADDLE_WITH_HIP
 #include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
-#endif
 
 namespace phi {
 namespace fusion {
@@ -42,7 +40,6 @@ void FusedBiasDropoutResidualLnKernel(
     DenseTensor* dropout_mask_out,
     DenseTensor* ln_mean,
     DenseTensor* ln_variance) {
-#ifndef PADDLE_WITH_HIP
   using U = phi::funcs::LayerNormParamType<T>;
   auto* x_data = x.data<T>();
   auto* bias_data = (bias.get_ptr() == nullptr) ? nullptr : bias->data<T>();
@@ -95,14 +92,20 @@
       y_data,
       ln_mean_data,
       ln_var_data);
-#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FusedBiasDropoutResidualLnKernel not support for rocm"));
-#endif
 }
 
 }  // namespace fusion
 }  // namespace phi
 
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedBiasDropoutResidualLnKernel,
+                   float,
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
+}
+#else
 PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
                    GPU,
                    ALL_LAYOUT,
@@ -112,3 +115,4 @@ PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
                    phi::dtype::float16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
 }
+#endif
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h b/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h
index e5f5c9ba50ba45..d2cd2f1b545a7c 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h
@@ -35,7 +35,11 @@ struct GeluFunctor {
 template <typename T>
 struct FastGeluFunctor {
   inline __device__ T operator()(const T x) const {
+#ifdef PADDLE_WITH_HIP
+    assert(0 && "ROCM does not support FastGelu");
+#else
     return phi::GeluFwd(x);
+#endif
   }
 };
@@ -92,8 +96,8 @@ __global__ void FusedDropoutActBias(
   int row_id = blockIdx.y;
   int idx = row_id * cols + col_id;
 
-  curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, increment, &state);
+  GPURAND(StatePhilox4_32_10_t) state;
+  GPURAND(_init)(seed, idx, increment, &state);
 
   const T factor =
       phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
index 801f070251fb2c..8994d521382335 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
@@ -202,18 +202,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
             ? NoMaskBwFunctor(1.0f - dropout_rate)
             : NoMaskBwFunctor(1.0f - dropout_rate, 1.0f);
-#ifdef PADDLE_WITH_HIP
-    VectorizedDropoutBackward>
-        <<>>(0,
-                numel,
-                seed_data,  // idx: 2 need save
-                x_grad_data,
-                y_grad_data,
-                out_grad_data,
-                increment,  // idx: 6 need save
-                main_offset,
-                functor);
-#else
     // we assume seed/offset is same across iterations
     // seed_offset_data should preserved by cudaGraph pool
     const phi::GPUContext* dev_ctx_p = &dev_ctx;
@@ -233,9 +221,13 @@
       cudaKernelCallback = [=](unsigned int id) {
         void* functionPtr = reinterpret_cast<void*>(
             &(VectorizedDropoutBackward>));
+#ifdef PADDLE_WITH_HIP
+        hipFunction_t cudaFunc = reinterpret_cast<hipFunction_t>(functionPtr);
+#else
         cudaFunction_t cudaFunc;
         PADDLE_ENFORCE_GPU_SUCCESS(
             cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+#endif
         VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
                  << " functionPtr = " << functionPtr;
@@ -257,7 +249,6 @@
     VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
              << ", increment = " << increment;
-#endif
   }
 }
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu
index c95c5fbf0ca3de..54ec3604bbee93 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu
@@ -186,18 +186,6 @@ void FusedDropoutAddKernel(const Context& dev_ctx,
     auto dst_functor =
         NoMaskFwFunctor(1.0f - dropout_rate, upscale_in_train);
-#ifdef PADDLE_WITH_HIP
-    VectorizedDropoutForward>
-        <<>>(0,
-                numel,
-                seed_data,  // need save
-                x_data,
-                y_data,
-                out_data,
-                increment,  // need save
-                main_offset,
-                dst_functor);
-#else
     // we assume seed/offset is same across iterations
     // seed_offset_data should preserved by cudaGraph pool
     const phi::GPUContext* dev_ctx_p = &dev_ctx;
@@ -237,9 +225,13 @@
      cudaKernelCallback = [=](unsigned int id) {
        void* functionPtr = reinterpret_cast<void*>(
            &(VectorizedDropoutForward>));
+#ifdef PADDLE_WITH_HIP
+        hipFunction_t cudaFunc = reinterpret_cast<hipFunction_t>(functionPtr);
+#else
        cudaFunction_t cudaFunc;
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+#endif
        VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
                 << " functionPtr = " << functionPtr;
@@ -260,7 +252,6 @@
     VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
              << ", increment = " << increment;
-#endif
   } else {
     using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     MT factor = static_cast<MT>(1.0f - dropout_rate);
diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_common.h b/paddle/phi/kernels/fusion/gpu/fused_dropout_common.h
index 2ef46378b1b9bd..ef9ecbb435fdba 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_dropout_common.h
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_common.h
@@ -20,10 +20,25 @@ limitations under the License.
 */
 #include
 #endif
 
+#ifdef PADDLE_WITH_HIP
+#include
+#include
+#include
+#include
+#endif
+
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
+#ifdef PADDLE_WITH_HIP
+#define GPU(str) hip##str
+#define GPURAND(str) hiprand##str
+#else
+#define GPU(str) cuda##str
+#define GPURAND(str) curand##str
+#endif
+
 namespace phi {
 namespace fusion {
@@ -63,26 +78,29 @@ inline phi::backends::gpu::GpuLaunchConfig Get1DBlocksAnd2DGrids(
 }
 
 template <int VecSize>
-__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
+__forceinline__ __device__ void RandVec(GPURAND(StatePhilox4_32_10_t) * state,
                                         float *data);
 
 template <>
-__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state,
+__forceinline__ __device__ void RandVec<1>(GPURAND(StatePhilox4_32_10_t) *
+                                               state,
                                            float *data) {
-  data[0] = curand_uniform(state);
+  data[0] = GPURAND(_uniform)(state);
 }
 
 template <>
-__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state,
+__forceinline__ __device__ void RandVec<2>(GPURAND(StatePhilox4_32_10_t) *
+                                               state,
                                            float *data) {
-  data[0] = curand_uniform(state);
-  data[1] = curand_uniform(state);
+  data[0] = GPURAND(_uniform)(state);
+  data[1] = GPURAND(_uniform)(state);
 }
 
 template <>
-__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
+__forceinline__ __device__ void RandVec<4>(GPURAND(StatePhilox4_32_10_t) *
+                                               state,
                                            float *data) {
-  float4 rand4 = curand_uniform4(state);
+  float4 rand4 = GPURAND(_uniform4)(state);
   data[0] = rand4.x;
   data[1] = rand4.y;
   data[2] = rand4.w;
@@ -90,7 +108,8 @@ __forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
 }
 
 template <>
-__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
+__forceinline__ __device__ void RandVec<8>(GPURAND(StatePhilox4_32_10_t) *
+                                               state,
                                            float *data) {
   RandVec<4>(state, data);
   RandVec<4>(state, data + 4);
@@ -99,7 +118,7 @@ __forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
 template <typename T>
 inline void SetZero(const phi::GPUContext &ctx, T *ptr, const size_t size) {
   PADDLE_ENFORCE_GPU_SUCCESS(
-      cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream()));
+      GPU(MemsetAsync)(ptr, 0, size * sizeof(T), ctx.stream()));
 }
 
 /**
diff --git a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
index e31b24e7e105e5..221019531a5486 100644
--- a/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
+++ b/paddle/phi/kernels/fusion/gpu/fused_layernorm_kernel.cu
@@ -38,10 +38,19 @@ limitations under the License.
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
-#ifndef PADDLE_WITH_HIP
-#include
 #include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
 #include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
+#ifdef PADDLE_WITH_HIP
+#include
+#include
+#include
+namespace cub = hipcub;
+#define GPU(str) hip##str
+#define GPUMultiProcessorCount hipDeviceAttributeMultiprocessorCount
+#else
+#include
+#define GPU(str) cuda##str
+#define GPUMultiProcessorCount cudaDevAttrMultiProcessorCount
 #endif
 
 namespace phi {
@@ -50,9 +59,11 @@ namespace fusion {
 
 namespace {
 
-#ifndef PADDLE_WITH_HIP
-
+#ifdef PADDLE_WITH_HIP
+constexpr int kWarpSize = 64;
+#else
 constexpr int kWarpSize = 32;
+#endif
 
 template <typename T>
 struct SumOp {
@@ -74,7 +85,11 @@ template
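
Aside: the GPU()/GPURAND() macros introduced in fused_dropout_common.h work by
preprocessor token pasting, so a single call site such as GPU(MemsetAsync) or
GPURAND(_uniform) expands to cudaMemsetAsync/curand_uniform on CUDA builds and
to hipMemsetAsync/hiprand_uniform on ROCm builds. The following minimal,
self-contained sketch shows the same technique; it is illustrative only, not
part of the patch, and assumes the PADDLE_WITH_HIP flag is supplied by the
build system exactly as in the files above.

// portability_sketch.cu - build with nvcc, or with hipcc -DPADDLE_WITH_HIP
#include <cstdio>

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#define GPU(str) hip##str  // GPU(Malloc) -> hipMalloc, GPU(Error_t) -> hipError_t
#else
#include <cuda_runtime.h>
#define GPU(str) cuda##str  // GPU(Malloc) -> cudaMalloc, GPU(Error_t) -> cudaError_t
#endif

int main() {
  void* ptr = nullptr;
  // Each call below is written once and resolves to the CUDA or HIP runtime
  // entry point at preprocessing time; this is what lets the patch delete the
  // duplicated #ifdef PADDLE_WITH_HIP kernel-launch branches.
  GPU(Error_t) err = GPU(Malloc)(&ptr, 256);
  if (err != GPU(Success)) return 1;
  err = GPU(Memset)(ptr, 0, 256);
  if (err != GPU(Success)) return 1;
  GPU(Free)(ptr);
  std::puts("ok");
  return 0;
}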