Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 103 additions & 40 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29456,6 +29456,7 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
SelectionDAG &DAG) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
EVT VT = Op.getValueType();

SDValue X = Op.getOperand(0);
SDValue Y = Op.getOperand(1);
SDLoc DL(Op);
Expand All @@ -29478,9 +29479,25 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
}

uint64_t SizeInBits = VT.getScalarSizeInBits();
// When one operand is a zero (positive or negative), we can avoid signed-zero
// ordering (potentially by flipping the operand order). x86's FMIN behaves
// like `X < Y ? X : Y`; FMAX likewise behaves like `X > Y ? X : Y`. Since
// zeroes compare as equal regardless of sign, the second operand is chosen
// whenever both operands are zero. So, here are the formulations for both
// operations and both signs that give us the correctly-ordered result, even
// if X is a signed zero:
//
// - `min(X, -0.0)` -> `X < -0.0 ? X : -0.0`
// - `min(X, 0.0)` -> `0.0 < X ? 0.0 : X`
// - `max(X, -0.0)` -> `-0.0 > X ? -0.0 : X`
// - `max(X, 0.0)` -> `X > 0.0 ? X : 0.0`
//
// Here, `PreferredZero` refers to the zero that goes in the "else" branch
// (it's "preferred" because it's chosen if both operands are equal or
// unordered). `OppositeZero` refers to the zero that *doesn't* go in the
// "else" branch.
APInt PreferredZero = APInt::getZero(SizeInBits);
APInt OppositeZero = PreferredZero;
EVT IVT = VT.changeTypeToInteger();
X86ISD::NodeType MinMaxOp;
if (IsMaxOp) {
MinMaxOp = X86ISD::FMAX;
Expand All @@ -29492,8 +29509,8 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
EVT SetCCType =
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);

// The tables below show the expected result of Max in cases of NaN and
// signed zeros.
// The tables below show the expected result of Max in cases of NaN and signed
// zeros.
//
// Y Y
// Num xNaN +0 -0
Expand All @@ -29503,12 +29520,9 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
// xNaN | X | X/Y | -0 | +0 | -0 |
// --------------- ---------------
//
// It is achieved by means of FMAX/FMIN with preliminary checks and operand
// reordering.
//
// We check if any of operands is NaN and return NaN. Then we check if any of
// operands is zero or negative zero (for fmaximum and fminimum respectively)
// to ensure the correct zero is returned.
// It is achieved by means of FMAX/FMIN with preliminary checks, operand
// reordering if one operand is a constant, and bitwise operations and selects
// to handle signed zero and NaN operands otherwise.
auto MatchesZero = [](SDValue Op, APInt Zero) {
Op = peekThroughBitcasts(Op);
if (auto *CstOp = dyn_cast<ConstantFPSDNode>(Op))
Expand Down Expand Up @@ -29539,15 +29553,17 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
Op->getFlags().hasNoSignedZeros() ||
DAG.isKnownNeverZeroFloat(X) ||
DAG.isKnownNeverZeroFloat(Y);
SDValue NewX, NewY;
bool ShouldHandleZeros = true;
SDValue NewX = X;
SDValue NewY = Y;
if (IgnoreSignedZero || MatchesZero(Y, PreferredZero) ||
MatchesZero(X, OppositeZero)) {
// Operands are already in right order or order does not matter.
NewX = X;
NewY = Y;
ShouldHandleZeros = false;
} else if (MatchesZero(X, PreferredZero) || MatchesZero(Y, OppositeZero)) {
NewX = Y;
NewY = X;
ShouldHandleZeros = false;
} else if (!VT.isVector() && (VT == MVT::f16 || Subtarget.hasDQI()) &&
(Op->getFlags().hasNoNaNs() || IsXNeverNaN || IsYNeverNaN)) {
if (IsXNeverNaN)
Expand All @@ -29569,35 +29585,12 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,
NewX = DAG.getSelect(DL, VT, NeedSwap, Y, X);
NewY = DAG.getSelect(DL, VT, NeedSwap, X, Y);
return DAG.getNode(MinMaxOp, DL, VT, NewX, NewY, Op->getFlags());
} else {
SDValue IsXSigned;
if (Subtarget.is64Bit() || VT != MVT::f64) {
SDValue XInt = DAG.getNode(ISD::BITCAST, DL, IVT, X);
SDValue ZeroCst = DAG.getConstant(0, DL, IVT);
IsXSigned = DAG.getSetCC(DL, SetCCType, XInt, ZeroCst, ISD::SETLT);
} else {
assert(VT == MVT::f64);
SDValue Ins = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, MVT::v2f64,
DAG.getConstantFP(0, DL, MVT::v2f64), X,
DAG.getVectorIdxConstant(0, DL));
SDValue VX = DAG.getNode(ISD::BITCAST, DL, MVT::v4f32, Ins);
SDValue Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, VX,
DAG.getVectorIdxConstant(1, DL));
Hi = DAG.getBitcast(MVT::i32, Hi);
SDValue ZeroCst = DAG.getConstant(0, DL, MVT::i32);
EVT SetCCType = TLI.getSetCCResultType(DAG.getDataLayout(),
*DAG.getContext(), MVT::i32);
IsXSigned = DAG.getSetCC(DL, SetCCType, Hi, ZeroCst, ISD::SETLT);
}
if (MinMaxOp == X86ISD::FMAX) {
NewX = DAG.getSelect(DL, VT, IsXSigned, X, Y);
NewY = DAG.getSelect(DL, VT, IsXSigned, Y, X);
} else {
NewX = DAG.getSelect(DL, VT, IsXSigned, Y, X);
NewY = DAG.getSelect(DL, VT, IsXSigned, X, Y);
}
}

EVT SVT = VT.getScalarType();
assert(VT.isFloatingPoint() && (SVT == MVT::f16 || SVT == MVT::f32 || SVT == MVT::f64) &&
"Unexpected type in LowerFMINIMUM_FMAXIMUM");

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we assert that VT is f64/f32/f16 here? Is there anyway that fp80/bf16 code can arrive here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to double-check; apparently bf16 code can call this lowering function, but I added an assert before here and it's not hit.

Just to make sure, I added tests for the scalar bf16 versions of minimum/maximum/minimumnum/maximumnum, and added CHECK lines for AVX512-BF16.

bool IgnoreNaN = DAG.getTarget().Options.NoNaNsFPMath ||
Op->getFlags().hasNoNaNs() || (IsXNeverNaN && IsYNeverNaN);

Expand All @@ -29612,10 +29605,80 @@ static SDValue LowerFMINIMUM_FMAXIMUM(SDValue Op, const X86Subtarget &Subtarget,

SDValue MinMax = DAG.getNode(MinMaxOp, DL, VT, NewX, NewY, Op->getFlags());

// We handle signed-zero ordering by taking the larger (or smaller) sign bit.
if (ShouldHandleZeros) {
const fltSemantics &Sem = VT.getFltSemantics();
unsigned EltBits = VT.getScalarSizeInBits();
bool IsFakeVector = !VT.isVector();
MVT LogicVT = VT.getSimpleVT();
if (IsFakeVector)
LogicVT = (VT == MVT::f64) ? MVT::v2f64
: (VT == MVT::f32) ? MVT::v4f32
: MVT::v8f16;

// We take the sign bit from the first operand and combine it with the
// output sign bit (see below). Right now, if ShouldHandleZeros is true, the
// operands will never have been swapped. If you add another optimization
// that swaps the input operands if one is a known value, make sure this
// logic stays correct!
SDValue LogicX = NewX;
SDValue LogicMinMax = MinMax;
if (IsFakeVector) {
// Promote scalars to vectors for bitwise operations.
LogicX = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, LogicVT, NewX);
LogicMinMax = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, LogicVT, MinMax);
}

// x86's min/max operations return the second operand if both inputs are
// signed zero. For the maximum operation, we want to "and" the sign bit of
// the output with the sign bit of the first operand--that means that if the
// first operand is +0.0, the output will be too. For the minimum, it's the
// opposite: we "or" the output sign bit with the sign bit of the first
// operand, ensuring that if the first operand is -0.0, the output will be
// too.
SDValue Result;
if (IsMaxOp) {
// getSignedMaxValue returns a bit pattern of all ones but the highest
// bit. We "or" that with the first operand, then "and" that with the max
// operation's result. That clears only the sign bit, and only if the
// first operand is positive.
SDValue OrMask = DAG.getConstantFP(
APFloat(Sem, APInt::getSignedMaxValue(EltBits)), DL, LogicVT);
SDValue MaskedSignBit =
DAG.getNode(X86ISD::FOR, DL, LogicVT, LogicX, OrMask);
Result =
DAG.getNode(X86ISD::FAND, DL, LogicVT, MaskedSignBit, LogicMinMax);
} else {
// Likewise, getSignMask returns a bit pattern with only the highest bit
// set. This one *sets* only the sign bit, and only if the first operand
// is *negative*.
SDValue AndMask = DAG.getConstantFP(
APFloat(Sem, APInt::getSignMask(EltBits)), DL, LogicVT);
SDValue MaskedSignBit =
DAG.getNode(X86ISD::FAND, DL, LogicVT, LogicX, AndMask);
Result =
DAG.getNode(X86ISD::FOR, DL, LogicVT, MaskedSignBit, LogicMinMax);
}

// Extract scalar back from vector.
if (IsFakeVector)
MinMax = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Result,
DAG.getVectorIdxConstant(0, DL));
else
MinMax = Result;
}

if (IgnoreNaN || DAG.isKnownNeverNaN(IsNum ? NewY : NewX))
return MinMax;

SDValue NaNSrc = IsNum ? MinMax : NewX;
// The x86 min/max return the second operand if either is NaN, which doesn't
// match the numeric or non-numeric semantics. For the non-numeric versions,
// we want to return NaN if either operand is NaN. To do that, we check if
// NewX (the first operand) is NaN, and select it if so. For the numeric
// versions, we want to return the non-NaN operand if there is one. So we
// check if NewY (the second operand) is NaN, and again select the first
// operand if so.
SDValue NaNSrc = IsNum ? NewY : NewX;
SDValue IsNaN = DAG.getSetCC(DL, SetCCType, NaNSrc, NaNSrc, ISD::SETUO);

return DAG.getSelect(DL, VT, IsNaN, NewX, MinMax);
Expand Down
42 changes: 16 additions & 26 deletions llvm/test/CodeGen/X86/avx512fp16-fminimum-fmaximum.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@ declare <32 x half> @llvm.maximum.v32f16(<32 x half>, <32 x half>)
define half @test_fminimum(half %x, half %y) {
; CHECK-LABEL: test_fminimum:
; CHECK: # %bb.0:
; CHECK-NEXT: vmovw %xmm0, %eax
; CHECK-NEXT: testw %ax, %ax
; CHECK-NEXT: sets %al
; CHECK-NEXT: kmovd %eax, %k1
; CHECK-NEXT: vmovaps %xmm1, %xmm2
; CHECK-NEXT: vmovsh %xmm0, %xmm0, %xmm2 {%k1}
; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm0 {%k1}
; CHECK-NEXT: vminsh %xmm2, %xmm0, %xmm1
; CHECK-NEXT: vminsh %xmm1, %xmm0, %xmm2
; CHECK-NEXT: vpbroadcastw {{.*#+}} xmm1 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
; CHECK-NEXT: vpternlogq {{.*#+}} xmm1 = (xmm1 & xmm0) | xmm2
; CHECK-NEXT: vcmpunordsh %xmm0, %xmm0, %k1
; CHECK-NEXT: vmovsh %xmm0, %xmm0, %xmm1 {%k1}
; CHECK-NEXT: vmovaps %xmm1, %xmm0
Expand Down Expand Up @@ -92,16 +87,12 @@ define half @test_fminimum_combine_cmps(half %x, half %y) {
define half @test_fmaximum(half %x, half %y) {
; CHECK-LABEL: test_fmaximum:
; CHECK: # %bb.0:
; CHECK-NEXT: vmovw %xmm0, %eax
; CHECK-NEXT: testw %ax, %ax
; CHECK-NEXT: sets %al
; CHECK-NEXT: kmovd %eax, %k1
; CHECK-NEXT: vmovaps %xmm0, %xmm2
; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm2 {%k1}
; CHECK-NEXT: vmaxsh %xmm1, %xmm0, %xmm2
; CHECK-NEXT: vpbroadcastw {{.*#+}} xmm1 = [NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN]
; CHECK-NEXT: vpternlogq {{.*#+}} xmm1 = xmm2 & (xmm1 | xmm0)
; CHECK-NEXT: vcmpunordsh %xmm0, %xmm0, %k1
; CHECK-NEXT: vmovsh %xmm0, %xmm0, %xmm1 {%k1}
; CHECK-NEXT: vmaxsh %xmm2, %xmm1, %xmm0
; CHECK-NEXT: vcmpunordsh %xmm1, %xmm1, %k1
; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm0 {%k1}
; CHECK-NEXT: vmovaps %xmm1, %xmm0
; CHECK-NEXT: retq
%r = call half @llvm.maximum.f16(half %x, half %y)
ret half %r
Expand Down Expand Up @@ -196,10 +187,9 @@ define <16 x half> @test_fmaximum_v16f16_nans(<16 x half> %x, <16 x half> %y) "n
define <32 x half> @test_fminimum_v32f16_szero(<32 x half> %x, <32 x half> %y) "no-nans-fp-math"="true" {
; CHECK-LABEL: test_fminimum_v32f16_szero:
; CHECK: # %bb.0:
; CHECK-NEXT: vpmovw2m %zmm0, %k1
; CHECK-NEXT: vpblendmw %zmm0, %zmm1, %zmm2 {%k1}
; CHECK-NEXT: vmovdqu16 %zmm1, %zmm0 {%k1}
; CHECK-NEXT: vminph %zmm2, %zmm0, %zmm0
; CHECK-NEXT: vminph %zmm1, %zmm0, %zmm1
; CHECK-NEXT: vpbroadcastw {{.*#+}} zmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
; CHECK-NEXT: vpternlogq {{.*#+}} zmm0 = (zmm0 & zmm2) | zmm1
; CHECK-NEXT: retq
%r = call <32 x half> @llvm.minimum.v32f16(<32 x half> %x, <32 x half> %y)
ret <32 x half> %r
Expand All @@ -208,12 +198,12 @@ define <32 x half> @test_fminimum_v32f16_szero(<32 x half> %x, <32 x half> %y) "
define <32 x half> @test_fmaximum_v32f16_nans_szero(<32 x half> %x, <32 x half> %y) {
; CHECK-LABEL: test_fmaximum_v32f16_nans_szero:
; CHECK: # %bb.0:
; CHECK-NEXT: vpmovw2m %zmm0, %k1
; CHECK-NEXT: vpblendmw %zmm1, %zmm0, %zmm2 {%k1}
; CHECK-NEXT: vmaxph %zmm1, %zmm0, %zmm2
; CHECK-NEXT: vpbroadcastw {{.*#+}} zmm1 = [NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN]
; CHECK-NEXT: vpternlogq {{.*#+}} zmm1 = zmm2 & (zmm1 | zmm0)
; CHECK-NEXT: vcmpunordph %zmm0, %zmm0, %k1
; CHECK-NEXT: vmovdqu16 %zmm0, %zmm1 {%k1}
; CHECK-NEXT: vmaxph %zmm2, %zmm1, %zmm0
; CHECK-NEXT: vcmpunordph %zmm1, %zmm1, %k1
; CHECK-NEXT: vmovdqu16 %zmm1, %zmm0 {%k1}
; CHECK-NEXT: vmovdqa64 %zmm1, %zmm0
; CHECK-NEXT: retq
%r = call <32 x half> @llvm.maximum.v32f16(<32 x half> %x, <32 x half> %y)
ret <32 x half> %r
Expand Down
Loading