@@ -457,6 +457,60 @@ CUDA_ATOMIC_WRAPPER(Mul, float) {
457457 return __int_as_float (old);
458458}
459459
// Read the 32-bit word that encloses a sub-word (1- or 2-byte) value and
// return the field selected by `mask`, shifted down to bit 0.
//
// Parameters:
//   base_addr - address of the enclosing word; must be 4-byte aligned.
//   mask      - bitmask selecting the sub-word field inside that word.
//   shift     - bit offset of the field within the word.
//
// NOTE(review): this returns the *extracted field*, not the full word.
// Callers currently seed their atomicCAS loops with this value, which makes
// the first CAS fail whenever the word's other bytes are non-zero — harmless
// (the loop converges), but confirm this is intentional.
__device__ __forceinline__ uint32_t __loadAligned(const uintptr_t base_addr,
                                                  uint32_t mask,
                                                  uint32_t shift) {
  const uint32_t *word_ptr = reinterpret_cast<const uint32_t *>(base_addr);
  const uint32_t word = *word_ptr;
  return (word & mask) >> shift;
}
467+
// Atomic multiply for uint8_t, emulated with a CAS loop on the enclosing
// 4-byte-aligned 32-bit word (CUDA has no native 1-byte atomics).
// Returns the value stored at `address` *before* the multiplication,
// matching the old-value semantics of the other atomic wrappers.
CUDA_ATOMIC_WRAPPER(Mul, uint8_t) {
  // Get the 4B-aligned word containing *address, and the byte's position
  // inside it.
  uintptr_t base_addr =
      reinterpret_cast<uintptr_t>(address) & ~static_cast<uintptr_t>(3);
  uint32_t offset_bytes =
      static_cast<uint32_t>(reinterpret_cast<uintptr_t>(address) - base_addr);
  uint32_t shift = offset_bytes * 8;
  uint32_t mask = 0xFFU << shift;

  uint32_t *aligned = reinterpret_cast<uint32_t *>(base_addr);
  // Seed the CAS loop with the FULL word so the first atomicCAS can succeed.
  // Seeding with the extracted byte (as __loadAligned returns) forces one
  // guaranteed-failed CAS iteration whenever neighboring bytes are non-zero.
  uint32_t old32 = *aligned;
  uint32_t assumed32 = 0;

  do {
    assumed32 = old32;
    uint8_t current = static_cast<uint8_t>((old32 & mask) >> shift);
    // Multiplication promotes to int; truncate back to uint8_t explicitly.
    uint8_t new_val = static_cast<uint8_t>(current * val);
    uint32_t new32 =
        (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);
    // atomicCAS returns the word's actual contents; loop until our snapshot
    // matches, i.e. the store happened.
    old32 = atomicCAS(aligned, assumed32, new32);
  } while (assumed32 != old32);

  return static_cast<uint8_t>((old32 & mask) >> shift);
}
490+
// Atomic multiply for int16_t, emulated with a CAS loop on the enclosing
// 4-byte-aligned 32-bit word (CUDA has no native 2-byte atomics).
// Assumes `address` is 2-byte aligned (standard for int16_t storage).
// Returns the value stored at `address` *before* the multiplication.
CUDA_ATOMIC_WRAPPER(Mul, int16_t) {
  // Get the 4B-aligned word containing *address, and the halfword's
  // position inside it (0 or 1).
  uintptr_t base_addr =
      reinterpret_cast<uintptr_t>(address) & ~static_cast<uintptr_t>(3);
  uint32_t offset_halfwords = static_cast<uint32_t>(
      (reinterpret_cast<uintptr_t>(address) - base_addr) / 2);
  uint32_t shift = offset_halfwords * 16;
  uint32_t mask = 0xFFFFU << shift;

  uint32_t *aligned = reinterpret_cast<uint32_t *>(base_addr);
  // Seed the CAS loop with the FULL word so the first atomicCAS can succeed.
  // Seeding with the extracted halfword (as __loadAligned returns) forces one
  // guaranteed-failed CAS iteration whenever the other halfword is non-zero.
  uint32_t old32 = *aligned;
  uint32_t assumed32 = 0;

  do {
    assumed32 = old32;
    int16_t current = static_cast<int16_t>((old32 & mask) >> shift);
    // Multiplication promotes to int; truncate back to int16_t explicitly.
    int16_t new_val = static_cast<int16_t>(current * val);
    uint32_t new32 =
        (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);
    // Loop until our snapshot matches the word, i.e. the store happened.
    old32 = atomicCAS(aligned, assumed32, new32);
  } while (assumed32 != old32);

  return static_cast<int16_t>((old32 & mask) >> shift);
}
513+
460514CUDA_ATOMIC_WRAPPER (Mul, double ) {
461515 unsigned long long int *const address_as_ull = // NOLINT
462516 reinterpret_cast <unsigned long long int *>(address); // NOLINT
@@ -943,6 +997,41 @@ CUDA_ATOMIC_WRAPPER(Min, phi::dtype::bfloat16) {
943997 }
944998}
945999
// Emulated atomic min/max for 1- and 2-byte integer types via a CAS loop on
// the enclosing 4-byte-aligned 32-bit word. Generates
// `Dtype CudaAtomic##OpType(Dtype *address, const Dtype val)`, which returns
// the value stored at `address` *before* the update.
// `op` is a binary function (e.g. min, max) applied as op(current, val).
// NOTE: the parameter was previously named `operator` — a reserved C++
// keyword that the preprocessor tolerates but some toolchains reject;
// renamed to `op` (invocations are unchanged).
// Uses `if constexpr`, so this requires C++17.
#define DEFINE_ATOMIC_MINMAX(Dtype, OpType, op)                              \
  __device__ __forceinline__ Dtype CudaAtomic##OpType(Dtype *address,        \
                                                      const Dtype val) {     \
    uintptr_t base_addr =                                                    \
        reinterpret_cast<uintptr_t>(address) & ~static_cast<uintptr_t>(3);   \
    uint32_t offset_bytes = static_cast<uint32_t>(                           \
        reinterpret_cast<uintptr_t>(address) - base_addr);                   \
    uint32_t shift = 0, mask = 0;                                            \
    if constexpr (sizeof(Dtype) == 1) {                                      \
      shift = offset_bytes * 8;                                              \
      mask = 0xFFU << shift;                                                 \
    } else {                                                                 \
      shift = (offset_bytes / 2) * 16;                                       \
      mask = 0xFFFFU << shift;                                               \
    }                                                                        \
    uint32_t *aligned = reinterpret_cast<uint32_t *>(base_addr);             \
    /* Seed with the FULL word so the first atomicCAS can succeed;        */ \
    /* seeding with the extracted field wastes one CAS round-trip.        */ \
    uint32_t old32 = *aligned, assumed32 = 0;                                \
    Dtype current = 0;                                                       \
    do {                                                                     \
      assumed32 = old32;                                                     \
      current = static_cast<Dtype>((old32 & mask) >> shift);                 \
      /* op promotes to int; truncate back to Dtype explicitly. */           \
      Dtype new_val = static_cast<Dtype>(op(current, val));                  \
      uint32_t new32 =                                                       \
          (old32 & ~mask) | (static_cast<uint32_t>(new_val) << shift);       \
      old32 = atomicCAS(aligned, assumed32, new32);                          \
    } while (assumed32 != old32);                                            \
    /* `current` was extracted from the snapshot the successful CAS */       \
    /* matched, i.e. the pre-update value. */                                \
    return current;                                                          \
  }

DEFINE_ATOMIC_MINMAX(int16_t, Min, min)
DEFINE_ATOMIC_MINMAX(int16_t, Max, max)
DEFINE_ATOMIC_MINMAX(uint8_t, Min, min)
DEFINE_ATOMIC_MINMAX(uint8_t, Max, max)

#undef DEFINE_ATOMIC_MINMAX
1034+
9461035#ifdef PADDLE_WITH_CUDA
9471036/*
9481037 * One thead block deals with elementwise atomicAdd for vector of len.
0 commit comments