@@ -16,48 +16,31 @@ limitations under the License. */
1616
1717#include " hip/hip_fp16.h"
1818#include " hip/hip_runtime.h"
19- #include " paddle/fluid/platform/float16.h"
2019
2120namespace paddle {
2221namespace platform {
2322
2423#define CREATE_SHFL_MASK (mask, predicate ) mask = 0u ;
2524
26- static __forceinline__ __device__ float CudaShuffleDownSync (unsigned mask, float val,
25+ template <typename T>
26+ static __forceinline__ __device__ T CudaShuffleDownSync (unsigned mask, T val,
2727 int delta, int width = 32 ) {
2828 return __shfl_down (val, delta, width);
2929}
3030
31- static __forceinline__ __device__ float CudaShuffleSync (unsigned mask, float val, int src_line,
31+ template <typename T>
32+ static __forceinline__ __device__ T CudaShuffleSync (unsigned mask, T val, int src_line,
3233 int width = 32 ) {
3334 return __shfl (val, src_line, width);
3435}
3536
36- static __forceinline__ __device__ int CudaShuffleDownSync (unsigned mask, int val,
37- int delta, int width) {
38- return __shfl_down (val, delta, width);
39- }
40-
41- static __forceinline__ __device__ int CudaShuffleSync (unsigned mask, int val, int src_line,
42- int width) {
43- return __shfl (val, src_line, width);
44- }
45-
46- static __forceinline__ __device__ paddle::platform::float16 CudaShuffleDownSync (unsigned mask, paddle::platform::float16 val,
47- int delta, int width) {
48- return (float )__shfl_down ((float )val, delta, width);
49- }
50-
51- static __forceinline__ __device__ paddle::platform::float16 CudaShuffleSync (unsigned mask, paddle::platform::float16 val, int src_line,
52- int width) {
53- return (float )__shfl ((float )val, src_line, width);
54- }
55-
37+ template <>
5638static __forceinline__ __device__ double CudaShuffleDownSync (unsigned mask, double val,
5739 int delta, int width) {
5840 return (float )__shfl_down ((float )val, delta, width);
5941}
6042
43+ template <>
6144static __forceinline__ __device__ double CudaShuffleSync (unsigned mask, double val, int src_line,
6245 int width) {
6346 return (float )__shfl ((float )val, src_line, width);
0 commit comments