@@ -95,17 +95,22 @@ class MathTypedTest : public MathTest {
9595 Tuple (&b, {IsFinite (x), IsInf (x), IsPosInf (x), IsNegInf (x), IsNan (x)});
9696
9797 bool has_inf = std::numeric_limits<T>::has_infinity;
98+ bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
99+ bool is_finite = !has_inf && !has_nan;
100+ bool is_nan_only = !has_inf && has_nan;
101+
98102 auto expected = LiteralUtil::MakeTupleOwned (
99- LiteralUtil::CreateR1<bool >(
100- { true , true , true , true , true , false , false , false , false }),
103+ LiteralUtil::CreateR1<bool >({ true , true , true , true , true , is_finite,
104+ is_finite, is_finite, is_finite }),
101105 LiteralUtil::CreateR1<bool >({false , false , false , false , false , has_inf,
102106 has_inf, false , false }),
103107 LiteralUtil::CreateR1<bool >(
104108 {false , false , false , false , false , has_inf, false , false , false }),
105109 LiteralUtil::CreateR1<bool >(
106110 {false , false , false , false , false , false , has_inf, false , false }),
107111 LiteralUtil::CreateR1<bool >({false , false , false , false , false ,
108- !has_inf, !has_inf, true , true }));
112+ is_nan_only, is_nan_only, has_nan,
113+ has_nan}));
109114 ComputeAndCompareLiteral (&b, expected, {});
110115 }
111116
@@ -118,10 +123,11 @@ class MathTypedTest : public MathTest {
118123 LiteralUtil::CreateR1<T>({T{-0.0 }, T{0 }, T{1 }, T{-1 }, inf, -inf, nan}),
119124 &b));
120125
126+ bool is_mx = std::is_same_v<T, tsl::float4_e2m1fn>;
121127 ComputeAndCompareLiteral (
122128 &b,
123129 LiteralUtil::CreateR1<bool >(
124- {has_negative_zero_v<T>, false , false , false , false , false , false }),
130+ {has_negative_zero_v<T>, false , false , false , false , false , is_mx }),
125131 {}, error_spec_);
126132 }
127133
@@ -136,6 +142,9 @@ class MathTypedTest : public MathTest {
136142 // For good measure, we also check pow with an exponent other than 0.5.
137143 void TestSqrtPowInequivalence () {
138144 SetFastMathDisabled (true );
145+ if (std::is_same_v<T, tsl::float4_e2m1fn>) {
146+ GTEST_SKIP () << " Skipping due to low precision" ;
147+ }
139148
140149 // Tests disable constant folding by default, but this test needs it
141150 // enabled, otherwise we don't tickle the bug we're trying to catch.
@@ -181,19 +190,24 @@ class MathTypedTest : public MathTest {
181190 &b);
182191 Erf (x);
183192
184- bool has_inf = std::numeric_limits<T>::has_infinity;
185- std::vector<T> expected = {
186- has_inf ? T (-1 ) : nan, has_inf ? T (1 ) : nan, T (-0 ), T (0 ), T (-1 ), T (1 )};
193+ bool inf_as_nan = !std::numeric_limits<T>::has_infinity &&
194+ std::numeric_limits<T>::has_quiet_NaN;
195+ std::vector<T> expected = {inf_as_nan ? nan : T (-1 ),
196+ inf_as_nan ? nan : T (1 ),
197+ T (-0 ),
198+ T (0 ),
199+ T (-1 ),
200+ T (1 )};
187201
188202 ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
189203 }
190204};
191205
192206// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
193207using TestTypes =
194- ::testing::Types<tsl::float8_e3m4 , tsl::float8_e4m3 , tsl::float8_e4m3fnuz ,
195- tsl::float8_e4m3b11fnuz , tsl::float8_e5m2 ,
196- tsl::float8_e5m2fnuz,
208+ ::testing::Types<tsl::float4_e2m1fn , tsl::float8_e3m4 , tsl::float8_e4m3 ,
209+ tsl::float8_e4m3fnuz , tsl::float8_e4m3b11fnuz ,
210+ tsl::float8_e5m2, tsl:: float8_e5m2fnuz,
197211#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
198212 Eigen::half,
199213#endif
0 commit comments