2424
2525namespace Eigen {
2626
27- using bfloat16 = paddle::platform::bfloat16;
2827using complex64 = paddle::platform::complex64;
2928using complex128 = paddle::platform::complex128;
3029using float16 = paddle::platform::float16;
@@ -33,30 +32,31 @@ template <typename T>
3332struct NumTraits ;
3433
3534template <>
36- struct NumTraits <bfloat16> : GenericNumTraits<bfloat16> {
35+ struct NumTraits <paddle::platform::bfloat16>
36+ : GenericNumTraits<paddle::platform::bfloat16> {
3737 enum {
3838 IsSigned = true ,
3939 IsInteger = false ,
4040 IsComplex = false ,
4141 RequireInitialization = false
4242 };
4343
44- HOSTDEVICE static inline bfloat16 epsilon () {
44+ HOSTDEVICE static inline paddle::platform:: bfloat16 epsilon () {
4545 return paddle::platform::raw_uint16_to_bfloat16 (0x3400 );
4646 }
47- HOSTDEVICE static inline bfloat16 dummy_precision () {
48- return bfloat16 (1e-5f );
47+ HOSTDEVICE static inline paddle::platform:: bfloat16 dummy_precision () {
48+ return paddle::platform:: bfloat16 (1e-5f );
4949 }
50- HOSTDEVICE static inline bfloat16 highest () {
50+ HOSTDEVICE static inline paddle::platform:: bfloat16 highest () {
5151 return paddle::platform::raw_uint16_to_bfloat16 (0x7f7f );
5252 }
53- HOSTDEVICE static inline bfloat16 lowest () {
53+ HOSTDEVICE static inline paddle::platform:: bfloat16 lowest () {
5454 return paddle::platform::raw_uint16_to_bfloat16 (0xff7f );
5555 }
56- HOSTDEVICE static inline bfloat16 infinity () {
56+ HOSTDEVICE static inline paddle::platform:: bfloat16 infinity () {
5757 return paddle::platform::raw_uint16_to_bfloat16 (0x7f80 );
5858 }
59- HOSTDEVICE static inline bfloat16 quiet_NaN () {
59+ HOSTDEVICE static inline paddle::platform:: bfloat16 quiet_NaN () {
6060 return paddle::platform::raw_uint16_to_bfloat16 (0xffc1 );
6161 }
6262};
@@ -137,68 +137,91 @@ namespace numext {
137137// ////////// bfloat methods /////////////
138138
139139template <>
140- HOSTDEVICE inline bool (isnan)(const bfloat16& a) {
140+ HOSTDEVICE inline bool (isnan)(const paddle::platform:: bfloat16& a) {
141141 return (paddle::platform::isnan)(a);
142142}
143143
144144template <>
145- HOSTDEVICE inline bool (isinf)(const bfloat16& a) {
145+ HOSTDEVICE inline bool (isinf)(const paddle::platform:: bfloat16& a) {
146146 return (paddle::platform::isinf)(a);
147147}
148148
149149template <>
150- HOSTDEVICE inline bool (isfinite)(const bfloat16& a) {
150+ HOSTDEVICE inline bool (isfinite)(const paddle::platform:: bfloat16& a) {
151151 return (paddle::platform::isfinite)(a);
152152}
153153
154154template <>
155- HOSTDEVICE inline bfloat16 exp (const bfloat16& a) {
156- return bfloat16 (::expf (static_cast <float >(a)));
155+ HOSTDEVICE inline paddle::platform::bfloat16 exp (
156+ const paddle::platform::bfloat16& a) {
157+ return paddle::platform::bfloat16 (::expf (static_cast <float >(a)));
157158}
158159
159160template <>
160- HOSTDEVICE inline bfloat16 erf (const bfloat16& a) {
161- return bfloat16 (::erff (static_cast <float >(a)));
161+ HOSTDEVICE inline paddle::platform::bfloat16 erf (
162+ const paddle::platform::bfloat16& a) {
163+ return paddle::platform::bfloat16 (::erff (static_cast <float >(a)));
162164}
163165
164166template <>
165- HOSTDEVICE inline bfloat16 log (const bfloat16& a) {
166- return bfloat16 (::logf (static_cast <float >(a)));
167+ HOSTDEVICE inline paddle::platform::bfloat16 log (
168+ const paddle::platform::bfloat16& a) {
169+ return paddle::platform::bfloat16 (::logf (static_cast <float >(a)));
167170}
168171
169172template <>
170- HOSTDEVICE inline bfloat16 tanh (const bfloat16& a) {
171- return bfloat16 (::tanhf (static_cast <float >(a)));
173+ HOSTDEVICE inline paddle::platform::bfloat16 tanh (
174+ const paddle::platform::bfloat16& a) {
175+ return paddle::platform::bfloat16 (::tanhf (static_cast <float >(a)));
172176}
173177
174178template <>
175- HOSTDEVICE inline bfloat16 sqrt (const bfloat16& a) {
176- return bfloat16 (::sqrtf (static_cast <float >(a)));
179+ HOSTDEVICE inline paddle::platform::bfloat16 sqrt (
180+ const paddle::platform::bfloat16& a) {
181+ return paddle::platform::bfloat16 (::sqrtf (static_cast <float >(a)));
177182}
178183
179184template <>
180- HOSTDEVICE inline bfloat16 ceil (const bfloat16& a) {
181- return bfloat16 (::ceilf (static_cast <float >(a)));
185+ HOSTDEVICE inline paddle::platform::bfloat16 ceil (
186+ const paddle::platform::bfloat16& a) {
187+ return paddle::platform::bfloat16 (::ceilf (static_cast <float >(a)));
182188}
183189
184190template <>
185- HOSTDEVICE inline bfloat16 floor (const bfloat16& a) {
186- return bfloat16 (::floorf (static_cast <float >(a)));
191+ HOSTDEVICE inline paddle::platform::bfloat16 floor (
192+ const paddle::platform::bfloat16& a) {
193+ return paddle::platform::bfloat16 (::floorf (static_cast <float >(a)));
187194}
188195
189196template <>
190- HOSTDEVICE inline bfloat16 round (const bfloat16& a) {
191- return bfloat16 (::roundf (static_cast <float >(a)));
197+ HOSTDEVICE inline paddle::platform::bfloat16 round (
198+ const paddle::platform::bfloat16& a) {
199+ return paddle::platform::bfloat16 (::roundf (static_cast <float >(a)));
192200}
193201
194202template <>
195- HOSTDEVICE inline bfloat16 pow (const bfloat16& a, const bfloat16& b) {
196- return bfloat16 (::powf (static_cast <float >(a), static_cast <float >(b)));
203+ HOSTDEVICE inline paddle::platform::bfloat16 pow (
204+ const paddle::platform::bfloat16& a, const paddle::platform::bfloat16& b) {
205+ return paddle::platform::bfloat16 (
206+ ::powf (static_cast <float >(a), static_cast<float>(b)));
197207}
198208
199209template <>
200- HOSTDEVICE inline bfloat16 abs (const bfloat16& a) {
201- return bfloat16 (::fabs (static_cast <float >(a)));
210+ HOSTDEVICE inline paddle::platform::bfloat16 abs (
211+ const paddle::platform::bfloat16& a) {
212+ return paddle::platform::bfloat16 (::fabs (static_cast <float >(a)));
213+ }
214+
215+ template <>
216+ HOSTDEVICE inline paddle::platform::bfloat16 mini (
217+ const paddle::platform::bfloat16& a, const paddle::platform::bfloat16& b) {
218+ return b < a ? b : a;
219+ }
220+
221+ template <>
222+ HOSTDEVICE inline paddle::platform::bfloat16 maxi (
223+ const paddle::platform::bfloat16& a, const paddle::platform::bfloat16& b) {
224+ return a < b ? b : a;
202225}
203226
204227// ////////// complex64 methods /////////////
@@ -398,5 +421,15 @@ HOSTDEVICE inline float16 abs(const float16& a) {
398421 return float16 (::fabs (static_cast <float >(a)));
399422}
400423
424+ template <>
425+ HOSTDEVICE inline float16 mini (const float16& a, const float16& b) {
426+ return b < a ? b : a;
427+ }
428+
429+ template <>
430+ HOSTDEVICE inline float16 maxi (const float16& a, const float16& b) {
431+ return a < b ? b : a;
432+ }
433+
401434} // namespace numext
402435} // namespace Eigen
0 commit comments