@@ -163,7 +163,8 @@ int GetSignificandBits(mlir::FloatType ty) {
163163}
164164
165165int GetExponentBias (mlir::FloatType ty) {
166- return 1 - llvm::APFloat::semanticsMinExponent (ty.getFloatSemantics ());
166+ return 1 - llvm::APFloat::semanticsMinExponent (ty.getFloatSemantics ()) -
167+ ty.isFloat8E8M0FNU (); // No zero exponent for E8M0: the predicate's result is subtracted as an integer, lowering the bias by one for that type only.
167168}
168169
169170bool IsFNUZ (mlir::FloatType ty) {
@@ -215,6 +216,8 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) {
215216 return (bits & 0b0111'1111 ) == 0b0111'1111 ;
216217 } else if (ty.isFloat8E3M4 ()) {
217218 return (bits & 0b0111'1111 ).cmp (ma::CmpIPredicate::ugt, 0b0111'0000 );
219+ } else if (ty.isFloat8E8M0FNU ()) {
220+ return bits == 0xFF ;
218221 }
219222 return bits == 0x80 ;
220223}
@@ -294,6 +297,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
294297 } else {
295298 wide_int_ty = b.getIntegerType (
296299 std::max (from_int_ty.getWidth (), to_int_ty.getWidth ()));
300+ // Avoid overflow for bit shifts.
301+ auto may_overflow = [&](mlir::Type a, mlir::Type b) {
302+ return a.isFloat8E8M0FNU () && b.isF16 ();
303+ };
304+ if (may_overflow (from_ty, to_ty) || may_overflow (to_ty, from_ty)) {
305+ wide_int_ty = b.getI32Type ();
306+ }
297307 }
298308 auto convert_int = [&](mlir::Type ty, Value v) -> Val {
299309 if (v.getType () == ty) {
@@ -320,11 +330,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
320330 };
321331
322332 // Shift bits to destination type, without sign bit.
323- Val from_sign_bit = from_bits. shrui (from_width - 1 ) != 0 ;
324- from_bits = from_bits & (( 1ULL << (from_width - 1 )) - 1 );
325-
326- Value result_is_inf = IsInf (value, b );
327- Value input_is_nan = IsNaN (value, b);
333+ Val from_sign_bit;
334+ if (!from_ty. isFloat8E8M0FNU ( )) {
335+ from_sign_bit = from_bits. shrui (from_width - 1 ) != 0 ;
336+ from_bits = from_bits & (( 1ULL << (from_width - 1 )) - 1 );
337+ }
328338
329339 auto cst_bits = [&](llvm::APFloat f) {
330340 return cst (b.getIntegerType (llvm::APFloat::getSizeInBits (f.getSemantics ())),
@@ -338,9 +348,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
338348 if (to_ty.isFloat4E2M1FN ()) {
339349 to_inf = cst_bits (llvm::APFloat::getLargest (to_ty.getFloatSemantics ()));
340350 to_nan = to_zero | 0x8 ;
351+ } else if (to_ty.isFloat8E8M0FNU ()) {
352+ to_inf = to_nan;
353+ to_zero = Val{to_nan, &b};
341354 }
342355
343- auto round_bits_to_nearest_even = [&](Val bits, Val roundoff) {
356+ auto round_bits_to_nearest_even = [&](Val bits, Val roundoff,
357+ bool use_implicit_bit = false ) {
344358 assert (bits.value .getType () == roundoff.value .getType ());
345359 // Round to nearest even by adding a bias term.
346360 // Consider a bit pattern
@@ -350,9 +364,10 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
350364 // - L is 1, R is 1, OR
351365 // - L is 0, R is 1, any T is one.
352366 // We do this by adding L to a bit pattern consisting of all T = 1.
353- Val rounded = (bits.shrui (roundoff) & 1 ) +
354- (bits.MakeConstant (1 ).shl (roundoff - 1 ) - 1 );
355- Val bias{b.create <SelectOp>(roundoff == 0 , roundoff, rounded), &b};
367+ Val bias = !use_implicit_bit
368+ ? (bits.shrui (roundoff) & 1 ) +
369+ (bits.MakeConstant (1 ).shl (roundoff - 1 ) - 1 )
370+ : bits.MakeConstant (1 ).shl (roundoff - 1 );
356371 return bits + bias;
357372 };
358373
@@ -362,9 +377,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
362377 // Round the mantissa if it is shrinking.
363378 Val rounded_from_bits = convert_int (wide_int_ty, from_bits);
364379 if (digit_shift < 0 ) {
365- rounded_from_bits = round_bits_to_nearest_even (
366- from_bits, from_bits.MakeConstant (-digit_shift)) &
367- ~((1ll << (-digit_shift)) - 1 );
380+ rounded_from_bits =
381+ round_bits_to_nearest_even (
382+ rounded_from_bits, rounded_from_bits.MakeConstant (-digit_shift),
383+ /* use_implicit_bit=*/ to_mantissa == 0 ) &
384+ ~((1ll << (-digit_shift)) - 1 );
368385 }
369386
370387 // Re-bias the exponent.
@@ -431,10 +448,12 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
431448 Value biased_exp_sle_zero = biased_exponent.cmp (CmpIPredicate::sle, 0 );
432449 bits.value =
433450 b.create <SelectOp>(biased_exp_sle_zero, subnormal_bits, normal_bits);
434- if (digit_shift > 0 ) {
451+ if (digit_shift >= 0 ) {
435452 bits = bits.shl (digit_shift);
436453 } else {
437- bits = round_bits_to_nearest_even (bits, bits.MakeConstant (-digit_shift));
454+ bits = round_bits_to_nearest_even (
455+ bits, bits.MakeConstant (-digit_shift),
456+ /* use_implicit_bit=*/ to_mantissa == 0 && exp_offset != 0 );
438457 bits = bits.shrui (-digit_shift);
439458 }
440459 bits = convert_int (to_int_ty, bits);
@@ -443,11 +462,11 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
443462 } else if (to_min_exp > from_min_exp) {
444463 // `To` supports fewer exponents near zero which means that some values in
445464 // `From` may become subnormal.
446- Val unbiased_exp = biased_from_exp - from_bias;
447- Val biased_to_exp = unbiased_exp + to_bias;
465+ Val biased_to_exp = biased_from_exp + (to_bias - from_bias);
448466 // Subnormals and zero.
449467 // Round and shift mantissa down.
450- Val from_has_leading_one = biased_from_exp != 0 ;
468+ Val from_has_leading_one =
469+ !from_ty.isFloat8E8M0FNU () ? biased_from_exp != 0 : cst (i32_ty, 1 );
451470 Val from_has_leading_one_i32 = convert_int (i32_ty, from_has_leading_one);
452471 from_has_leading_one = convert_int (from_int_ty, from_has_leading_one);
453472 Val exponent_shift_i32 =
@@ -482,7 +501,13 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
482501 result);
483502 }
484503
485- if (IsFNUZ (to_ty)) {
504+ Value result_is_inf = IsInf (value, b);
505+ Value input_is_nan = IsNaN (value, b);
506+
507+ if (to_ty.isFloat8E8M0FNU ()) {
508+ // Converting a negative number to E8M0 results in NaN.
509+ input_is_nan = from_sign_bit | input_is_nan;
510+ } else if (IsFNUZ (to_ty)) {
486511 // Clear the sign bit if the result is zero (the output has no negative
487512 // zero). Handle the edge case when the input is zero and the result is not.
488513 Val result_is_non_zero =
@@ -494,14 +519,17 @@ Value EmitFloatConversion(Value value, mlir::FloatType to_ty,
494519 from_sign_bit = from_sign_bit ^ input_is_nan;
495520 }
496521
522+ if (!from_ty.isFloat8E8M0FNU ()) {
523+ result = b.create <SelectOp>(from_bits == 0 , to_zero, result);
524+ }
497525 result = b.create <SelectOp>(result_is_inf, to_inf, result);
498- result = b.create <SelectOp>(from_bits == 0 , to_zero, result);
499526 result = b.create <SelectOp>(input_is_nan, to_nan, result);
500527
501- Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth () - 1 ));
502-
503528 // Insert sign bit.
504- result = b.create <SelectOp>(from_sign_bit, neg_result, result);
529+ if (!from_ty.isFloat8E8M0FNU ()) {
530+ Value neg_result = Val{result, &b} | (1ll << (to_int_ty.getWidth () - 1 ));
531+ result = b.create <SelectOp>(from_sign_bit, neg_result, result);
532+ }
505533 result = b.create <ma::BitcastOp>(to_ty, result);
506534 return result;
507535}
@@ -598,6 +626,14 @@ struct RewriteAbsFPattern : public mlir::OpRewritePattern<mlir::math::AbsFOp> {
598626 return rewriter.notifyMatchFailure (op,
599627 " not an f8 (or less) or bf16 absf" );
600628 }
629+
630+ // If type is unsigned (E8M0), the operation is no-op.
631+ if (!llvm::APFloat::semanticsHasSignedRepr (
632+ src.getType ().getFloatSemantics ())) {
633+ rewriter.replaceAllOpUsesWith (op, op.getOperand ());
634+ return mlir::success ();
635+ }
636+
601637 mlir::ImplicitLocOpBuilder b (op.getLoc (), rewriter);
602638 mlir::Type i_ty = rewriter.getIntegerType (src.getType ().getWidth ());
603639 Val value{b.create <ma::BitcastOp>(i_ty, src), &b};
0 commit comments