@@ -11048,10 +11048,126 @@ SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
                      Cmp.getValue(1));
 }
 
-SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
-                                              SDValue RHS, SDValue TVal,
-                                              SDValue FVal, const SDLoc &dl,
-                                              SelectionDAG &DAG) const {
+/// Emit a vector comparison for floating-point values, producing a mask.
+static SDValue emitVectorComparison(SDValue LHS, SDValue RHS,
+                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
+                                    const SDLoc &DL, SelectionDAG &DAG) {
+  EVT SrcVT = LHS.getValueType();
+  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
+         "function only supposed to emit natural comparisons");
+
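+  // Only conditions with a single direct FCM* encoding are handled here; the
+  // caller is responsible for synthesizing other conditions from a second
+  // compare and/or an inversion of the result.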
+  switch (CC) {
+  default:
+    return SDValue();
+  case AArch64CC::NE: {
+    SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
+    // Use vector semantics for the inversion to potentially save a copy
+    // between SIMD and regular registers.
+    if (!LHS.getValueType().isVector()) {
+      EVT VecVT =
+          EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
+      SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+      SDValue MaskVec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT,
+                                    DAG.getUNDEF(VecVT), Fcmeq, Zero);
+      SDValue InvertedMask = DAG.getNOT(DL, MaskVec, VecVT);
+      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, InvertedMask, Zero);
+    }
+    return DAG.getNOT(DL, Fcmeq, VT);
+  }
+  case AArch64CC::EQ:
+    return DAG.getNode(AArch64ISD::FCMEQ, DL, VT, LHS, RHS);
+  case AArch64CC::GE:
+    return DAG.getNode(AArch64ISD::FCMGE, DL, VT, LHS, RHS);
+  case AArch64CC::GT:
+    return DAG.getNode(AArch64ISD::FCMGT, DL, VT, LHS, RHS);
+  case AArch64CC::LE:
+    if (!NoNans)
+      return SDValue();
+    // If we ignore NaNs then we can use the LS implementation.
+    [[fallthrough]];
+  case AArch64CC::LS:
+    return DAG.getNode(AArch64ISD::FCMGE, DL, VT, RHS, LHS);
+  case AArch64CC::LT:
+    if (!NoNans)
+      return SDValue();
+    // If we ignore NaNs then we can use the MI implementation.
+    [[fallthrough]];
+  case AArch64CC::MI:
+    return DAG.getNode(AArch64ISD::FCMGT, DL, VT, RHS, LHS);
+  }
+}
+
+/// For SELECT_CC, when the true/false values are (-1, 0) and the compared
+/// values are scalars, try to emit a mask-generating vector instruction.
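+/// For example, selecting between all-ones and all-zeros on (setolt x, y)
+/// becomes a single FCMGT y, x, which produces the mask directly in a SIMD
+/// register.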
+static SDValue emitFloatCompareMask(SDValue LHS, SDValue RHS, SDValue TVal,
+                                    SDValue FVal, ISD::CondCode CC, bool NoNaNs,
+                                    const SDLoc &DL, SelectionDAG &DAG) {
+  assert(!LHS.getValueType().isVector());
+  assert(!RHS.getValueType().isVector());
+
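+  // The transformation is only profitable if the select really produces a
+  // mask: the true/false values must be all-ones and all-zeros (in either
+  // order), which is exactly what the FCM* nodes return.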
+  auto *CTVal = dyn_cast<ConstantSDNode>(TVal);
+  auto *CFVal = dyn_cast<ConstantSDNode>(FVal);
+  if (!CTVal || !CFVal)
+    return {};
+  if (!(CTVal->isAllOnes() && CFVal->isZero()) &&
+      !(CTVal->isZero() && CFVal->isAllOnes()))
+    return {};
+
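+  // Canonicalize to the (-1, 0) form by inverting the condition if necessary.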
+  if (CTVal->isZero())
+    CC = ISD::getSetCCInverse(CC, LHS.getValueType());
+
+  EVT VT = TVal.getValueType();
+  if (VT.getSizeInBits() != LHS.getValueType().getSizeInBits())
+    return {};
+
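+  // SETUO/SETO would normally need two compares. If at least one operand is
+  // known not to be NaN, checking for (un)ordered reduces to comparing the
+  // other operand with itself.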
+  if (!NoNaNs && (CC == ISD::SETUO || CC == ISD::SETO)) {
+    bool OneNaN = false;
+    if (LHS == RHS) {
+      OneNaN = true;
+    } else if (DAG.isKnownNeverNaN(RHS)) {
+      OneNaN = true;
+      RHS = LHS;
+    } else if (DAG.isKnownNeverNaN(LHS)) {
+      OneNaN = true;
+      LHS = RHS;
+    }
+    if (OneNaN)
+      CC = (CC == ISD::SETUO) ? ISD::SETUNE : ISD::SETOEQ;
+  }
+
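+  // The FP condition may map to one or two AArch64 compares (e.g. SETONE
+  // needs FCMGT with the operands both ways) plus an optional inversion of
+  // the combined mask.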
+  AArch64CC::CondCode CC1;
+  AArch64CC::CondCode CC2;
+  bool ShouldInvert = false;
+  changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
+  SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, VT, DL, DAG);
+  SDValue Cmp2;
+  if (CC2 != AArch64CC::AL) {
+    Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, VT, DL, DAG);
+    if (!Cmp2)
+      return {};
+  }
+  if (!Cmp2 && !ShouldInvert)
+    return Cmp;
+
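+  // Combine or invert the partial masks with vector semantics: insert the
+  // scalar mask(s) into lane 0 of a 128-bit vector, operate there, and
+  // extract lane 0 of the result, keeping everything in SIMD registers.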
+  EVT VecVT =
+      EVT::getVectorVT(*DAG.getContext(), VT, 128 / VT.getSizeInBits());
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+  Cmp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT), Cmp,
+                    Zero);
+  if (Cmp2) {
+    Cmp2 = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, DAG.getUNDEF(VecVT),
+                       Cmp2, Zero);
+    Cmp = DAG.getNode(ISD::OR, DL, VecVT, Cmp, Cmp2);
+  }
+  if (ShouldInvert)
+    Cmp = DAG.getNOT(DL, Cmp, VecVT);
+  Cmp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Cmp, Zero);
+  return Cmp;
+}
+
+SDValue AArch64TargetLowering::LowerSELECT_CC(
+    ISD::CondCode CC, SDValue LHS, SDValue RHS, SDValue TVal, SDValue FVal,
+    iterator_range<SDNode::user_iterator> Users, bool HasNoNaNs,
+    const SDLoc &dl, SelectionDAG &DAG) const {
   // Handle f128 first, because it will result in a comparison of some RTLIB
   // call result against zero.
   if (LHS.getValueType() == MVT::f128) {
@@ -11234,6 +11350,27 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
          LHS.getValueType() == MVT::f64);
   assert(LHS.getValueType() == RHS.getValueType());
   EVT VT = TVal.getValueType();
+
+  // If the purpose of the comparison is to select between all-ones and
+  // all-zeros, try to use a vector comparison, because the operands are
+  // already stored in SIMD registers.
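+  // This is only done when every user moves the result back into a vector,
+  // so the mask never has to be transferred to a GPR.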
+  if (Subtarget->isNeonAvailable() && all_of(Users, [](const SDNode *U) {
+        switch (U->getOpcode()) {
+        default:
+          return false;
+        case ISD::INSERT_VECTOR_ELT:
+        case ISD::SCALAR_TO_VECTOR:
+        case AArch64ISD::DUP:
+          return true;
+        }
+      })) {
+    bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || HasNoNaNs;
+    SDValue VectorCmp =
+        emitFloatCompareMask(LHS, RHS, TVal, FVal, CC, NoNaNs, dl, DAG);
+    if (VectorCmp)
+      return VectorCmp;
+  }
+
   SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
 
   // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
@@ -11320,15 +11457,18 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
   SDValue RHS = Op.getOperand(1);
   SDValue TVal = Op.getOperand(2);
   SDValue FVal = Op.getOperand(3);
+  bool HasNoNaNs = Op->getFlags().hasNoNaNs();
   SDLoc DL(Op);
-  return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+  return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNaNs, DL,
+                        DAG);
 }
 
 SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
                                            SelectionDAG &DAG) const {
   SDValue CCVal = Op->getOperand(0);
   SDValue TVal = Op->getOperand(1);
   SDValue FVal = Op->getOperand(2);
+  bool HasNoNaNs = Op->getFlags().hasNoNaNs();
   SDLoc DL(Op);
 
   EVT Ty = Op.getValueType();
@@ -11395,7 +11535,8 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
                        DAG.getUNDEF(MVT::f32), FVal);
   }
 
-  SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+  SDValue Res =
+      LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, Op->users(), HasNoNaNs, DL, DAG);
 
   if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
     return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
@@ -15648,47 +15789,6 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
   llvm_unreachable("unexpected shift opcode");
 }
 
-static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
-                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
-                                    const SDLoc &dl, SelectionDAG &DAG) {
-  EVT SrcVT = LHS.getValueType();
-  assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
-         "function only supposed to emit natural comparisons");
-
-  if (SrcVT.getVectorElementType().isFloatingPoint()) {
-    switch (CC) {
-    default:
-      return SDValue();
-    case AArch64CC::NE: {
-      SDValue Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
-      return DAG.getNOT(dl, Fcmeq, VT);
-    }
-    case AArch64CC::EQ:
-      return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
-    case AArch64CC::GE:
-      return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
-    case AArch64CC::GT:
-      return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
-    case AArch64CC::LE:
-      if (!NoNans)
-        return SDValue();
-      // If we ignore NaNs then we can use to the LS implementation.
-      [[fallthrough]];
-    case AArch64CC::LS:
-      return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
-    case AArch64CC::LT:
-      if (!NoNans)
-        return SDValue();
-      // If we ignore NaNs then we can use to the MI implementation.
-      [[fallthrough]];
-    case AArch64CC::MI:
-      return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
-    }
-  }
-
-  return SDValue();
-}
-
 SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
                                            SelectionDAG &DAG) const {
   if (Op.getValueType().isScalableVector())
@@ -15737,15 +15837,14 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
   bool ShouldInvert;
   changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
 
-  bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
-  SDValue Cmp =
-      EmitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
+  bool NoNaNs =
+      getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
+  SDValue Cmp = emitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
   if (!Cmp.getNode())
     return SDValue();
 
   if (CC2 != AArch64CC::AL) {
-    SDValue Cmp2 =
-        EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
+    SDValue Cmp2 = emitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
     if (!Cmp2.getNode())
       return SDValue();
 
@@ -25502,6 +25601,28 @@ static SDValue performDUPCombine(SDNode *N,
   }
 
   if (N->getOpcode() == AArch64ISD::DUP) {
+    // If the instruction is known to produce a scalar in SIMD registers, we
+    // can duplicate it across the vector lanes using DUPLANE instead of
+    // moving it to a GPR first. For example, this allows us to handle:
+    //   v4i32 = DUP (i32 (FCMGT (f32, f32)))
+    SDValue Op = N->getOperand(0);
+    // FIXME: Ideally, we should be able to handle all instructions that
+    // produce a scalar value in FPRs.
+    if (Op.getOpcode() == AArch64ISD::FCMEQ ||
+        Op.getOpcode() == AArch64ISD::FCMGE ||
+        Op.getOpcode() == AArch64ISD::FCMGT) {
+      EVT ElemVT = VT.getVectorElementType();
+      EVT ExpandedVT = VT;
+      // Insert into a 128-bit vector to match DUPLANE's pattern.
+      if (VT.getSizeInBits() != 128)
+        ExpandedVT = EVT::getVectorVT(*DCI.DAG.getContext(), ElemVT,
+                                      128 / ElemVT.getSizeInBits());
+      SDValue Zero = DCI.DAG.getConstant(0, DL, MVT::i64);
+      SDValue Vec = DCI.DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpandedVT,
+                                    DCI.DAG.getUNDEF(ExpandedVT), Op, Zero);
+      return DCI.DAG.getNode(getDUPLANEOp(ElemVT), DL, VT, Vec, Zero);
+    }
+
     if (DCI.isAfterLegalizeDAG()) {
       // If scalar dup's operand is extract_vector_elt, try to combine them into
       // duplane. For example,