Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/platform/device/gpu/gpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class RecordedGpuMallocHelper {
CUDADeviceGuard guard(dev_id_);
gpuError_t result;
#ifdef PADDLE_WITH_HIP
phi::backends::gpu::CUDAGraphCaptureModeGuard capture_mode_guard;
if (UNLIKELY(malloc_managed_memory)) {
result = hipMallocManaged(ptr, size);
} else {
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/core/visit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ namespace phi {
"`"); \
} \
}()
#if defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_XPU)
#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
Expand Down
2 changes: 0 additions & 2 deletions paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,9 @@ if(WITH_ROCM)
"gpu/lu_kernel.cu"
"gpu/matrix_rank_kernel.cu"
"gpu/matrix_rank_tol_kernel.cu"
"gpu/multiclass_nms3_kernel.cu"
"gpu/put_along_axis_grad_kernel.cu"
"gpu/put_along_axis_kernel.cu"
"gpu/qr_kernel.cu"
"gpu/rms_norm_grad_kernel.cu"
"gpu/svd_kernel.cu"
"gpudnn/mha_cudnn_frontend.cu"
"fusion/gpu/block_multi_head_attention_kernel.cu"
Expand Down
26 changes: 9 additions & 17 deletions paddle/phi/kernels/funcs/dropout_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,19 +349,6 @@ void DropoutFwGPUKernelDriver(
} else {
bool copy_in_kernel = GetSeedDataAndIncrement(
dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
#ifdef PADDLE_WITH_HIP
VectorizedRandomGenerator<T>
<<<grid_size, block_size, 0, stream>>>(0,
size,
seed_data,
dropout_prob,
x_data,
mask_data,
y_data,
upscale_in_train,
increment,
main_offset);
#else
const phi::GPUContext* dev_ctx_p = &dev_ctx;
auto gen_cuda = dev_ctx.GetGenerator();
auto state_index = gen_cuda->GetStateIndex();
Expand All @@ -370,10 +357,11 @@ void DropoutFwGPUKernelDriver(
parameterSetter = [offset, dev_ctx_p, state_index, is_fix_seed](
phi::backends::gpu::gpuKernelParams& params) {
if (!is_fix_seed) {
// we assume seed is null pointer
// seed copy to cpu is meaningless here
// we assume seed is null pointer
// seed copy to cpu is meaningless here
#ifndef PADDLE_WITH_HIP
assert(seed_tensor_ptr == nullptr);

#endif
auto gen_cuda = dev_ctx_p->GetGenerator();
// ensure the generator use correct state index
gen_cuda->SetStateIndex(state_index);
Expand All @@ -393,9 +381,14 @@ void DropoutFwGPUKernelDriver(
cudaKernelCallback = [=](unsigned int id) {
void* functionPtr =
reinterpret_cast<void*>(&(VectorizedRandomGenerator<T>));
#ifdef PADDLE_WITH_HIP
hipFunction_t cudaFunc =
reinterpret_cast<hipFunction_t>(functionPtr);
#else
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGetFuncBySymbol(&cudaFunc, functionPtr));
#endif
VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
<< " functionPtr = " << functionPtr;

Expand All @@ -417,7 +410,6 @@ void DropoutFwGPUKernelDriver(

VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
<< ", increment = " << increment;
#endif
}
} else {
if (upscale_in_train) {
Expand Down
24 changes: 21 additions & 3 deletions paddle/phi/kernels/funcs/layer_norm_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,14 @@ __inline__ __device__ double rsqrt_(const double val) {
return ::rsqrt(val);
}

#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) || defined(PADDLE_WITH_HIP)
// Specialization of rsqrt_ for `half`: delegates to the hrsqrt intrinsic.
template <>
__inline__ __device__ half rsqrt_(const half val) {
  const half inv_sqrt = hrsqrt(val);
  return inv_sqrt;
}
#endif

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T,
typename U,
typename ScaleT = U,
Expand Down Expand Up @@ -254,7 +254,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(

#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
#ifdef PADDLE_WITH_HIP
mu_local += __shfl_xor(mu_local, it);
#else
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
#endif
}
if (WARPS_N > 1) {
if (lane == 0) {
Expand Down Expand Up @@ -290,7 +294,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(

#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
#ifdef PADDLE_WITH_HIP
var_local += __shfl_xor(var_local, it);
#else
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
#endif
}

if (WARPS_N > 1) {
Expand Down Expand Up @@ -546,7 +554,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
}
}

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <bool IsFusedDropoutResidualLn,
bool NeedDDropoutSrcPtr,
typename T,
Expand Down Expand Up @@ -678,16 +686,26 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
#pragma unroll
// row reduction among 32 threads.
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
#ifdef PADDLE_WITH_HIP
sum_loss1 += __shfl_xor(sum_loss1, it);
sum_loss2 += __shfl_xor(sum_loss2, it);
#else
sum_loss1 += __shfl_xor_sync(uint32_t(-1), sum_loss1, it);
sum_loss2 += __shfl_xor_sync(uint32_t(-1), sum_loss2, it);
#endif
}
sum_loss1 *= rn;
sum_loss2 *= rn;
} else {
#pragma unroll
for (int it = 16; it > 0; it /= 2) {
#ifdef PADDLE_WITH_HIP
sum_loss1 += __shfl_down(sum_loss1, it);
sum_loss2 += __shfl_down(sum_loss2, it);
#else
sum_loss1 += __shfl_down_sync(uint32_t(-1), sum_loss1, it);
sum_loss2 += __shfl_down_sync(uint32_t(-1), sum_loss2, it);
#endif
}

if (lane == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#endif
Expand All @@ -21,9 +26,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#endif

namespace phi {
namespace fusion {
Expand Down Expand Up @@ -51,7 +54,6 @@ void FusedBiasDropoutResidualLnGradKernel(
DenseTensor* bias_grad,
DenseTensor* ln_scale_grad,
DenseTensor* ln_bias_grad) {
#ifndef PADDLE_WITH_HIP
using U = LayerNormParamType<T>;
auto* d_y_data = y_grad.data<T>();
auto* ln_scale_data =
Expand Down Expand Up @@ -114,19 +116,24 @@ void FusedBiasDropoutResidualLnGradKernel(
d_x_data,
d_bias_data,
d_residual_data);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FusedBiasDropoutResidualLnGradKernel not surpport for rocm"));
#endif
}

} // namespace fusion
} // namespace phi

// Registers FusedBiasDropoutResidualLnGradKernel as the GPU kernel for
// fused_bias_dropout_residual_layer_norm_grad (all layouts).
// The HIP branch registers float and float16 only; the non-HIP branch
// additionally registers double.
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedBiasDropoutResidualLnGradKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedBiasDropoutResidualLnGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#endif

namespace phi {
namespace fusion {
Expand All @@ -42,7 +40,6 @@ void FusedBiasDropoutResidualLnKernel(
DenseTensor* dropout_mask_out,
DenseTensor* ln_mean,
DenseTensor* ln_variance) {
#ifndef PADDLE_WITH_HIP
using U = phi::funcs::LayerNormParamType<T>;
auto* x_data = x.data<T>();
auto* bias_data = (bias.get_ptr() == nullptr) ? nullptr : bias->data<T>();
Expand Down Expand Up @@ -95,14 +92,20 @@ void FusedBiasDropoutResidualLnKernel(
y_data,
ln_mean_data,
ln_var_data);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FusedBiasDropoutResidualLnKernel not support for rocm"));
#endif
}
} // namespace fusion
} // namespace phi

#ifdef PADDLE_WITH_HIP
// HIP build: registers FusedBiasDropoutResidualLnKernel as the GPU kernel for
// fused_bias_dropout_residual_layer_norm (all layouts) with float and float16.
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
GPU,
ALL_LAYOUT,
phi::fusion::FusedBiasDropoutResidualLnKernel,
float,
phi::dtype::float16) {
// Output index 1 is declared as UINT8 regardless of the kernel's T.
// NOTE(review): presumably this is dropout_mask_out — confirm against the
// op's output order.
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#else
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
GPU,
ALL_LAYOUT,
Expand All @@ -112,3 +115,4 @@ PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#endif
8 changes: 6 additions & 2 deletions paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ struct GeluFunctor {
// Element-wise device functor wrapping phi::GeluFwd<T, true>.
//
// Non-HIP builds: returns phi::GeluFwd<T, true>(x).
// HIP (ROCm) builds: this path is unsupported. Builds with assertions enabled
// trap on the assert; builds with NDEBUG fall through and return the input
// unchanged, so the function always returns a defined value.
template <typename T>
struct FastGeluFunctor {
  inline __device__ T operator()(const T x) const {
#ifdef PADDLE_WITH_HIP
    assert(0 && "ROCM does not support FastGelu");
    // assert() expands to nothing under NDEBUG. Without an explicit return,
    // control would flow off the end of a non-void function, which is
    // undefined behavior.
    return x;
#else
    return phi::GeluFwd<T, true>(x);
#endif
  }
};

Expand Down Expand Up @@ -92,8 +96,8 @@ __global__ void FusedDropoutActBias(
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;

curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
GPURAND(StatePhilox4_32_10_t) state;
GPURAND(_init)(seed, idx, increment, &state);

const T factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
Expand Down
17 changes: 4 additions & 13 deletions paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -202,18 +202,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
? NoMaskBwFunctor<T, float>(1.0f - dropout_rate)
: NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);

#ifdef PADDLE_WITH_HIP
VectorizedDropoutBackward<T, NoMaskBwFunctor<T, float>>
<<<grid_size, block_size, 0, stream>>>(0,
numel,
seed_data, // idx: 2 need save
x_grad_data,
y_grad_data,
out_grad_data,
increment, // idx: 6 need save
main_offset,
functor);
#else
// we assume seed/offset is same across iterations
// seed_offset_data should preserved by cudaGraph pool
const phi::GPUContext* dev_ctx_p = &dev_ctx;
Expand All @@ -233,9 +221,13 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
cudaKernelCallback = [=](unsigned int id) {
void* functionPtr = reinterpret_cast<void*>(
&(VectorizedDropoutBackward<T, NoMaskBwFunctor<T, float>>));
#ifdef PADDLE_WITH_HIP
hipFunction_t cudaFunc = reinterpret_cast<hipFunction_t>(functionPtr);
#else
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGetFuncBySymbol(&cudaFunc, functionPtr));
#endif
VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
<< " functionPtr = " << functionPtr;

Expand All @@ -257,7 +249,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,

VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
<< ", increment = " << increment;
#endif
}
}

Expand Down
17 changes: 4 additions & 13 deletions paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,6 @@ void FusedDropoutAddKernel(const Context& dev_ctx,
auto dst_functor =
NoMaskFwFunctor<T, float>(1.0f - dropout_rate, upscale_in_train);

#ifdef PADDLE_WITH_HIP
VectorizedDropoutForward<T, NoMaskFwFunctor<T, float>>
<<<grid_size, block_size, 0, stream>>>(0,
numel,
seed_data, // need save
x_data,
y_data,
out_data,
increment, // need save
main_offset,
dst_functor);
#else
// we assume seed/offset is same across iterations
// seed_offset_data should preserved by cudaGraph pool
const phi::GPUContext* dev_ctx_p = &dev_ctx;
Expand Down Expand Up @@ -237,9 +225,13 @@ void FusedDropoutAddKernel(const Context& dev_ctx,
cudaKernelCallback = [=](unsigned int id) {
void* functionPtr = reinterpret_cast<void*>(
&(VectorizedDropoutForward<T, NoMaskFwFunctor<T, float>>));
#ifdef PADDLE_WITH_HIP
hipFunction_t cudaFunc = reinterpret_cast<hipFunction_t>(functionPtr);
#else
cudaFunction_t cudaFunc;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGetFuncBySymbol(&cudaFunc, functionPtr));
#endif
VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
<< " functionPtr = " << functionPtr;

Expand All @@ -260,7 +252,6 @@ void FusedDropoutAddKernel(const Context& dev_ctx,

VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
<< ", increment = " << increment;
#endif
} else {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT factor = static_cast<MT>(1.0f - dropout_rate);
Expand Down
Loading