@@ -3036,7 +3036,16 @@ template <typename T>
30363036struct FloorFunctor : public BaseActivationFunctor <T> {
30373037 template <typename Device, typename X, typename Out>
30383038 void operator ()(Device d, X x, Out out) const {
3039- out.device (d) = x.floor ();
3039+ if constexpr ((std::is_same<T, uint8_t >::value) ||
3040+ (std::is_same<T, int8_t >::value) ||
3041+ (std::is_same<T, uint16_t >::value) ||
3042+ (std::is_same<T, int16_t >::value) ||
3043+ (std::is_same<T, int >::value) ||
3044+ (std::is_same<T, int64_t >::value)) {
3045+ out.device (d) = x;
3046+ } else {
3047+ out.device (d) = x.floor ();
3048+ }
30403049 }
30413050};
30423051
@@ -3160,7 +3169,16 @@ template <typename T>
31603169struct CeilFunctor : public BaseActivationFunctor <T> {
31613170 template <typename Device, typename X, typename Out>
31623171 void operator ()(Device d, X x, Out out) const {
3163- out.device (d) = x.ceil ();
3172+ if constexpr ((std::is_same<T, uint8_t >::value) ||
3173+ (std::is_same<T, int8_t >::value) ||
3174+ (std::is_same<T, uint16_t >::value) ||
3175+ (std::is_same<T, int16_t >::value) ||
3176+ (std::is_same<T, int >::value) ||
3177+ (std::is_same<T, int64_t >::value)) {
3178+ out.device (d) = x;
3179+ } else {
3180+ out.device (d) = x.ceil ();
3181+ }
31643182 }
31653183};
31663184
@@ -5403,7 +5421,16 @@ struct CudaCeilFunctor : public BaseActivationFunctor<T> {
54035421 // ceil(x) = ceil(x)
54045422 __device__ __forceinline__ T operator ()(const T arg_x) const {
54055423 MPType x = static_cast <MPType>(arg_x);
5406- return static_cast <T>(ceil (x));
5424+ if constexpr ((std::is_same<T, uint8_t >::value) ||
5425+ (std::is_same<T, int8_t >::value) ||
5426+ (std::is_same<T, uint16_t >::value) ||
5427+ (std::is_same<T, int16_t >::value) ||
5428+ (std::is_same<T, int >::value) ||
5429+ (std::is_same<T, int64_t >::value)) {
5430+ return static_cast <T>(x);
5431+ } else {
5432+ return static_cast <T>(ceil (x));
5433+ }
54075434 }
54085435};
54095436
@@ -5492,7 +5519,16 @@ struct CudaFloorFunctor : public BaseActivationFunctor<T> {
54925519 // floor(x) = floor(x)
54935520 __device__ __forceinline__ T operator ()(const T arg_x) const {
54945521 MPType x = static_cast <MPType>(arg_x);
5495- return static_cast <T>(floor (x));
5522+ if constexpr ((std::is_same<T, uint8_t >::value) ||
5523+ (std::is_same<T, int8_t >::value) ||
5524+ (std::is_same<T, uint16_t >::value) ||
5525+ (std::is_same<T, int16_t >::value) ||
5526+ (std::is_same<T, int >::value) ||
5527+ (std::is_same<T, int64_t >::value)) {
5528+ return static_cast <T>(x);
5529+ } else {
5530+ return static_cast <T>(floor (x));
5531+ }
54965532 }
54975533};
54985534
0 commit comments