
Commit 4f5e4f5

[DCU] fix bugs and support some fused ops (#63217)
* [DCU] fix bugs and support some fused ops
* [DCU] fix a small bug
* Update fused_dropout_act_bias.h
* update fused_dropout_act_bias.h
* fix depthwise conv grad op bug
* fix hip graph test bugs
* update
* fix hip graph dropout bug
* code style
1 parent 0532d9a commit 4f5e4f5

21 files changed: +514 / -338 lines

paddle/fluid/platform/device/gpu/gpu_info.cc

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ class RecordedGpuMallocHelper {
   CUDADeviceGuard guard(dev_id_);
   gpuError_t result;
 #ifdef PADDLE_WITH_HIP
+  phi::backends::gpu::CUDAGraphCaptureModeGuard capture_mode_guard;
   if (UNLIKELY(malloc_managed_memory)) {
     result = hipMallocManaged(ptr, size);
   } else {
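The added CUDAGraphCaptureModeGuard keeps hipMalloc/hipMallocManaged from failing when memory is requested while a HIP graph is being captured. Below is a minimal sketch of what such an RAII guard typically does, assuming hipThreadExchangeStreamCaptureMode mirrors its CUDA counterpart; it is illustrative only and not Paddle's actual implementation.

// Illustrative sketch only: relax the calling thread's stream-capture mode
// around an allocation so the HIP runtime accepts it during graph capture.
#include <hip/hip_runtime.h>

class CaptureModeGuardSketch {
 public:
  explicit CaptureModeGuardSketch(
      hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed)
      : mode_(mode) {
    // Exchange: install the relaxed mode, remember the previous one.
    (void)hipThreadExchangeStreamCaptureMode(&mode_);
  }
  ~CaptureModeGuardSketch() {
    // Restore the previous capture mode on scope exit.
    (void)hipThreadExchangeStreamCaptureMode(&mode_);
  }

 private:
  hipStreamCaptureMode mode_;
};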

paddle/phi/core/visit_type.h

Lines changed: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ namespace phi {
         "`");                                   \
     }                                           \
   }()
-#if defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_HIP)
+#if defined(PADDLE_WITH_XPU)
 #define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
   [&] {                                     \
     const auto& __dtype__ = TYPE;           \

paddle/phi/kernels/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -209,11 +209,9 @@ if(WITH_ROCM)
       "gpu/lu_kernel.cu"
       "gpu/matrix_rank_kernel.cu"
       "gpu/matrix_rank_tol_kernel.cu"
-      "gpu/multiclass_nms3_kernel.cu"
       "gpu/put_along_axis_grad_kernel.cu"
       "gpu/put_along_axis_kernel.cu"
       "gpu/qr_kernel.cu"
-      "gpu/rms_norm_grad_kernel.cu"
       "gpu/svd_kernel.cu"
       "gpudnn/mha_cudnn_frontend.cu"
       "fusion/gpu/block_multi_head_attention_kernel.cu"

paddle/phi/kernels/funcs/dropout_impl.cu.h

Lines changed: 9 additions & 17 deletions
@@ -349,19 +349,6 @@ void DropoutFwGPUKernelDriver(
   } else {
     bool copy_in_kernel = GetSeedDataAndIncrement(
         dev_ctx, seed, is_fix_seed, seed_val, offset, &seed_data, &increment);
-#ifdef PADDLE_WITH_HIP
-    VectorizedRandomGenerator<T>
-        <<<grid_size, block_size, 0, stream>>>(0,
-                                               size,
-                                               seed_data,
-                                               dropout_prob,
-                                               x_data,
-                                               mask_data,
-                                               y_data,
-                                               upscale_in_train,
-                                               increment,
-                                               main_offset);
-#else
     const phi::GPUContext* dev_ctx_p = &dev_ctx;
     auto gen_cuda = dev_ctx.GetGenerator();
     auto state_index = gen_cuda->GetStateIndex();
@@ -370,10 +357,11 @@ void DropoutFwGPUKernelDriver(
     parameterSetter = [offset, dev_ctx_p, state_index, is_fix_seed](
                           phi::backends::gpu::gpuKernelParams& params) {
       if (!is_fix_seed) {
-        // we assume seed is null pointer
-        // seed copy to cpu is meaningless here
+        // we assume seed is null pointer
+        // seed copy to cpu is meaningless here
+#ifndef PADDLE_WITH_HIP
         assert(seed_tensor_ptr == nullptr);
-
+#endif
         auto gen_cuda = dev_ctx_p->GetGenerator();
         // ensure the generator use correct state index
         gen_cuda->SetStateIndex(state_index);
@@ -393,9 +381,14 @@ void DropoutFwGPUKernelDriver(
     cudaKernelCallback = [=](unsigned int id) {
       void* functionPtr =
           reinterpret_cast<void*>(&(VectorizedRandomGenerator<T>));
+#ifdef PADDLE_WITH_HIP
+      hipFunction_t cudaFunc =
+          reinterpret_cast<hipFunction_t>(functionPtr);
+#else
       cudaFunction_t cudaFunc;
       PADDLE_ENFORCE_GPU_SUCCESS(
           cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+#endif
       VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
                << " functionPtr = " << functionPtr;

@@ -417,7 +410,6 @@ void DropoutFwGPUKernelDriver(

       VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
                << ", increment = " << increment;
-#endif
     }
   } else {
     if (upscale_in_train) {
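The HIP-only direct kernel launch is removed so ROCm now follows the same CUDA-graph-aware driver path; the remaining divergence is how a device-function handle is obtained inside the graph callback. A condensed sketch of that pattern, with an illustrative helper name that is not the actual Paddle code:

// Sketch: obtain a kernel handle for graph-capture bookkeeping.
// CUDA resolves the symbol through cudaGetFuncBySymbol, while on HIP the
// host-side function pointer is reinterpreted as a hipFunction_t directly.
template <typename KernelT>
void* KernelHandleSketch(KernelT* kernel) {
  void* function_ptr = reinterpret_cast<void*>(kernel);
#ifdef PADDLE_WITH_HIP
  return reinterpret_cast<void*>(
      reinterpret_cast<hipFunction_t>(function_ptr));
#else
  cudaFunction_t func;
  // Real code wraps this call in PADDLE_ENFORCE_GPU_SUCCESS.
  cudaGetFuncBySymbol(&func, function_ptr);
  return reinterpret_cast<void*>(func);
#endif
}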

paddle/phi/kernels/funcs/layer_norm_impl.cu.h

Lines changed: 21 additions & 3 deletions
@@ -166,14 +166,14 @@ __inline__ __device__ double rsqrt_(const double val) {
   return ::rsqrt(val);
 }

-#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) || defined(PADDLE_WITH_HIP)
 template <>
 __inline__ __device__ half rsqrt_(const half val) {
   return hrsqrt(val);
 }
 #endif

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template <typename T,
           typename U,
           typename ScaleT = U,
@@ -254,7 +254,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(

 #pragma unroll
   for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
+#ifdef PADDLE_WITH_HIP
+    mu_local += __shfl_xor(mu_local, it);
+#else
     mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
+#endif
   }
   if (WARPS_N > 1) {
     if (lane == 0) {
@@ -290,7 +294,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(

 #pragma unroll
   for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
+#ifdef PADDLE_WITH_HIP
+    var_local += __shfl_xor(var_local, it);
+#else
     var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
+#endif
   }

   if (WARPS_N > 1) {
@@ -546,7 +554,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
   }
 }

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 template <bool IsFusedDropoutResidualLn,
           bool NeedDDropoutSrcPtr,
           typename T,
@@ -678,16 +686,26 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
 #pragma unroll
     // row reduction among 32 threads.
     for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
+#ifdef PADDLE_WITH_HIP
+      sum_loss1 += __shfl_xor(sum_loss1, it);
+      sum_loss2 += __shfl_xor(sum_loss2, it);
+#else
       sum_loss1 += __shfl_xor_sync(uint32_t(-1), sum_loss1, it);
       sum_loss2 += __shfl_xor_sync(uint32_t(-1), sum_loss2, it);
+#endif
     }
     sum_loss1 *= rn;
     sum_loss2 *= rn;
   } else {
 #pragma unroll
     for (int it = 16; it > 0; it /= 2) {
+#ifdef PADDLE_WITH_HIP
+      sum_loss1 += __shfl_down(sum_loss1, it);
+      sum_loss2 += __shfl_down(sum_loss2, it);
+#else
       sum_loss1 += __shfl_down_sync(uint32_t(-1), sum_loss1, it);
       sum_loss2 += __shfl_down_sync(uint32_t(-1), sum_loss2, it);
+#endif
     }

     if (lane == 0) {
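The warp reductions switch between __shfl_xor_sync/__shfl_down_sync (CUDA, with an explicit lane mask) and the unsynchronized __shfl_xor/__shfl_down intrinsics available on HIP. One way to keep such kernels readable is a small wrapper, sketched below with an assumed name; the diff above inlines the #ifdef instead.

// Illustrative helper (not part of this patch): hide the CUDA/HIP shuffle
// difference behind one device function.
template <typename T>
__device__ inline T WarpShflXorSketch(T val, int lane_mask) {
#ifdef PADDLE_WITH_HIP
  return __shfl_xor(val, lane_mask);
#else
  return __shfl_xor_sync(0xffffffffu, val, lane_mask);
#endif
}

// Usage inside the butterfly reduction:
//   for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
//     mu_local += WarpShflXorSketch(mu_local, it);
//   }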

paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu

Lines changed: 15 additions & 8 deletions
@@ -11,7 +11,12 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#ifndef PADDLE_WITH_HIP
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#else
 #include <cuda_fp16.h>
 #include <cub/cub.cuh>
 #endif
@@ -21,9 +26,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
-#ifndef PADDLE_WITH_HIP
 #include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
-#endif

 namespace phi {
 namespace fusion {
@@ -51,7 +54,6 @@ void FusedBiasDropoutResidualLnGradKernel(
     DenseTensor* bias_grad,
     DenseTensor* ln_scale_grad,
     DenseTensor* ln_bias_grad) {
-#ifndef PADDLE_WITH_HIP
   using U = LayerNormParamType<T>;
   auto* d_y_data = y_grad.data<T>();
   auto* ln_scale_data =
@@ -114,19 +116,24 @@ void FusedBiasDropoutResidualLnGradKernel(
       d_x_data,
       d_bias_data,
       d_residual_data);
-#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FusedBiasDropoutResidualLnGradKernel not surpport for rocm"));
-#endif
 }

 }  // namespace fusion
 }  // namespace phi

+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedBiasDropoutResidualLnGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+#else
 PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::fusion::FusedBiasDropoutResidualLnGradKernel,
                    float,
                    double,
                    phi::dtype::float16) {}
+#endif
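On ROCm the CUB dependency is satisfied through hipCUB, and `namespace cub = hipcub;` lets the existing `cub::`-qualified code compile unchanged; note the HIP registration also omits `double`. A hedged sketch of the alias pattern (the function name is illustrative, not from this patch):

// Sketch: the same BlockReduce code serves both backends once hipcub is
// aliased to cub (hipCUB mirrors CUB's interface).
#ifdef PADDLE_WITH_HIP
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif

template <int kBlockDim>
__device__ float BlockSumSketch(float val) {
  typedef cub::BlockReduce<float, kBlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  // Sum the per-thread values across the thread block.
  return BlockReduce(temp_storage).Sum(val);
}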

paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu

Lines changed: 11 additions & 7 deletions
@@ -17,9 +17,7 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
-#ifndef PADDLE_WITH_HIP
 #include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
-#endif

 namespace phi {
 namespace fusion {
@@ -42,7 +40,6 @@ void FusedBiasDropoutResidualLnKernel(
     DenseTensor* dropout_mask_out,
     DenseTensor* ln_mean,
     DenseTensor* ln_variance) {
-#ifndef PADDLE_WITH_HIP
   using U = phi::funcs::LayerNormParamType<T>;
   auto* x_data = x.data<T>();
   auto* bias_data = (bias.get_ptr() == nullptr) ? nullptr : bias->data<T>();
@@ -95,14 +92,20 @@ void FusedBiasDropoutResidualLnKernel(
       y_data,
       ln_mean_data,
       ln_var_data);
-#else
-  PADDLE_THROW(phi::errors::Unimplemented(
-      "FusedBiasDropoutResidualLnKernel not support for rocm"));
-#endif
 }
 }  // namespace fusion
 }  // namespace phi

+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedBiasDropoutResidualLnKernel,
+                   float,
+                   phi::dtype::float16) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
+}
+#else
 PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
                    GPU,
                    ALL_LAYOUT,
@@ -112,3 +115,4 @@ PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
                    phi::dtype::float16) {
   kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
 }
+#endif

paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h

Lines changed: 6 additions & 2 deletions
@@ -35,7 +35,11 @@ struct GeluFunctor {
 template <typename T>
 struct FastGeluFunctor {
   inline __device__ T operator()(const T x) const {
+#ifdef PADDLE_WITH_HIP
+    assert(0 && "ROCM does not support FastGelu");
+#else
     return phi::GeluFwd<T, true>(x);
+#endif
   }
 };

@@ -92,8 +96,8 @@ __global__ void FusedDropoutActBias(
   int row_id = blockIdx.y;
   int idx = row_id * cols + col_id;

-  curandStatePhilox4_32_10_t state;
-  curand_init(seed, idx, increment, &state);
+  GPURAND(StatePhilox4_32_10_t) state;
+  GPURAND(_init)(seed, idx, increment, &state);

   const T factor =
       phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
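GPURAND(...) replaces the hard-coded curand prefix so the same source expands to cuRAND on CUDA and hipRAND on ROCm. A plausible definition of such a token-pasting macro is sketched below; the actual macro Paddle uses may live elsewhere and differ in detail.

// Hypothetical definition of a prefix-mapping macro (illustrative only).
#ifdef PADDLE_WITH_HIP
#include <hiprand_kernel.h>
#define GPURAND(args) hiprand##args
#else
#include <curand_kernel.h>
#define GPURAND(args) curand##args
#endif

// GPURAND(StatePhilox4_32_10_t)
//   -> curandStatePhilox4_32_10_t / hiprandStatePhilox4_32_10_t
// GPURAND(_init)(seed, idx, increment, &state)
//   -> curand_init(...) / hiprand_init(...)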

paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu

Lines changed: 4 additions & 13 deletions
@@ -202,18 +202,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
           ? NoMaskBwFunctor<T, float>(1.0f - dropout_rate)
           : NoMaskBwFunctor<T, float>(1.0f - dropout_rate, 1.0f);

-#ifdef PADDLE_WITH_HIP
-  VectorizedDropoutBackward<T, NoMaskBwFunctor<T, float>>
-      <<<grid_size, block_size, 0, stream>>>(0,
-                                             numel,
-                                             seed_data,  // idx: 2 need save
-                                             x_grad_data,
-                                             y_grad_data,
-                                             out_grad_data,
-                                             increment,  // idx: 6 need save
-                                             main_offset,
-                                             functor);
-#else
   // we assume seed/offset is same across iterations
   // seed_offset_data should preserved by cudaGraph pool
   const phi::GPUContext* dev_ctx_p = &dev_ctx;
@@ -233,9 +221,13 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,
   cudaKernelCallback = [=](unsigned int id) {
     void* functionPtr = reinterpret_cast<void*>(
         &(VectorizedDropoutBackward<T, NoMaskBwFunctor<T, float>>));
+#ifdef PADDLE_WITH_HIP
+    hipFunction_t cudaFunc = reinterpret_cast<hipFunction_t>(functionPtr);
+#else
     cudaFunction_t cudaFunc;
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+#endif
     VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
              << " functionPtr = " << functionPtr;

@@ -257,7 +249,6 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx,

   VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
            << ", increment = " << increment;
-#endif
   }
 }

paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu

Lines changed: 4 additions & 13 deletions
@@ -186,18 +186,6 @@ void FusedDropoutAddKernel(const Context& dev_ctx,
   auto dst_functor =
       NoMaskFwFunctor<T, float>(1.0f - dropout_rate, upscale_in_train);

-#ifdef PADDLE_WITH_HIP
-  VectorizedDropoutForward<T, NoMaskFwFunctor<T, float>>
-      <<<grid_size, block_size, 0, stream>>>(0,
-                                             numel,
-                                             seed_data,  // need save
-                                             x_data,
-                                             y_data,
-                                             out_data,
-                                             increment,  // need save
-                                             main_offset,
-                                             dst_functor);
-#else
   // we assume seed/offset is same across iterations
   // seed_offset_data should preserved by cudaGraph pool
   const phi::GPUContext* dev_ctx_p = &dev_ctx;
@@ -237,9 +225,13 @@ void FusedDropoutAddKernel(const Context& dev_ctx,
   cudaKernelCallback = [=](unsigned int id) {
     void* functionPtr = reinterpret_cast<void*>(
         &(VectorizedDropoutForward<T, NoMaskFwFunctor<T, float>>));
+#ifdef PADDLE_WITH_HIP
+    hipFunction_t cudaFunc = reinterpret_cast<hipFunction_t>(functionPtr);
+#else
     cudaFunction_t cudaFunc;
     PADDLE_ENFORCE_GPU_SUCCESS(
         cudaGetFuncBySymbol(&cudaFunc, functionPtr));
+#endif
     VLOG(10) << "[cudaKernelCallback] cudaFunc = " << cudaFunc
              << " functionPtr = " << functionPtr;

@@ -260,7 +252,6 @@ void FusedDropoutAddKernel(const Context& dev_ctx,

   VLOG(10) << "NON_CUDA_GRAPH seed = " << seed_data
            << ", increment = " << increment;
-#endif
 } else {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   MT factor = static_cast<MT>(1.0f - dropout_rate);
