Skip to content

Commit addd5fc

Browse files
authored
miss format (#34771)
1 parent 4d2994c commit addd5fc

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

paddle/fluid/operators/math/bert_encoder_functor.cu

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,14 @@ namespace paddle {
2525
namespace operators {
2626
namespace math {
2727

// Reciprocal square root helper used by the LayerNorm* kernels below.
// The generic version promotes the operand to float so non-float types get
// a well-defined single-precision rsqrt; the result narrows back to T on
// return.
template <typename T>
__device__ __forceinline__ T local_rsqrt(T num) {
  // rsqrtf is the explicit single-precision CUDA math intrinsic; it matches
  // the rsqrt(float) overload but cannot silently resolve to the double path.
  return rsqrtf(static_cast<float>(num));
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
// Half-precision overload: hrsqrt keeps the computation in fp16 and avoids
// the float round-trip on architectures where the guard macro enables it.
__device__ __forceinline__ half local_rsqrt(half num) { return hrsqrt(num); }
#endif
2836
template <typename T, int TPB>
2937
__device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
3038
const int ld, const int idx,
@@ -39,7 +47,7 @@ __device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
3947

4048
if (threadIdx.x == 0) {
4149
mu = sum_kv.key;
42-
rsigma = rsqrt(sum_kv.value - mu * mu + eps);
50+
rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
4351
}
4452
__syncthreads();
4553

@@ -63,7 +71,7 @@ __device__ inline void LayerNorm(const kvp<T> &thread_data, const int ld,
6371

6472
if (threadIdx.x == 0) {
6573
mu = sum_kv.key;
66-
rsigma = rsqrt(sum_kv.value - mu * mu + eps);
74+
rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
6775
}
6876
__syncthreads();
6977

@@ -89,7 +97,7 @@ __device__ inline void LayerNorm2(const kvp<T> &thread_data, const int ld,
8997

9098
if (threadIdx.x == 0) {
9199
mu = sum_kv.key;
92-
rsigma = rsqrt(sum_kv.value - mu * mu + eps);
100+
rsigma = local_rsqrt(sum_kv.value - mu * mu + eps);
93101
}
94102
__syncthreads();
95103

0 commit comments

Comments (0)