@@ -3883,6 +3883,7 @@ class BoUpSLP {
38833883 enum CombinedOpcode {
38843884 NotCombinedOp = -1,
38853885 MinMax = Instruction::OtherOpsEnd + 1,
3886+ FMulAdd,
38863887 };
38873888 CombinedOpcode CombinedOp = NotCombinedOp;
38883889
@@ -4033,6 +4034,9 @@ class BoUpSLP {
40334034 /// Returns true if any scalar in the list is a copyable element.
40344035 bool hasCopyableElements() const { return !CopyableElements.empty(); }
40354036
4037+ /// Returns the state of the operations.
4038+ const InstructionsState &getOperations() const { return S; }
4039+
40364040 /// When ReuseReorderShuffleIndices is empty it just returns position of \p
40374041 /// V within vector of Scalars. Otherwise, try to remap on its reuse index.
40384042 unsigned findLaneForValue(Value *V) const {
@@ -11987,6 +11991,89 @@ void BoUpSLP::reorderGatherNode(TreeEntry &TE) {
1198711991 }
1198811992}
1198911993
11994+ /// Check if we can convert fadd/fsub sequence to FMAD.
11995+ /// \returns Cost of the FMAD, if conversion is possible, invalid cost otherwise.
11996+ static InstructionCost canConvertToFMA(ArrayRef<Value *> VL,
11997+ const InstructionsState &S,
11998+ DominatorTree &DT, const DataLayout &DL,
11999+ TargetTransformInfo &TTI,
12000+ const TargetLibraryInfo &TLI) {
12001+ assert(all_of(VL,
12002+ [](Value *V) {
12003+ return V->getType()->getScalarType()->isFloatingPointTy();
12004+ }) &&
12005+ "Can only convert to FMA for floating point types");
12006+ assert(S.isAddSubLikeOp() && "Can only convert to FMA for add/sub");
12007+
12008+ auto CheckForContractable = [&](ArrayRef<Value *> VL) {
12009+ FastMathFlags FMF;
12010+ FMF.set();
12011+ for (Value *V : VL) {
12012+ auto *I = dyn_cast<Instruction>(V);
12013+ if (!I)
12014+ continue;
12015+ if (S.isCopyableElement(I))
12016+ continue;
12017+ Instruction *MatchingI = S.getMatchingMainOpOrAltOp(I);
12018+ if (S.getMainOp() != MatchingI && S.getAltOp() != MatchingI)
12019+ continue;
12020+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12021+ FMF &= FPCI->getFastMathFlags();
12022+ }
12023+ return FMF.allowContract();
12024+ };
12025+ if (!CheckForContractable(VL))
12026+ return InstructionCost::getInvalid();
12027+ // fmul also should be contractable
12028+ InstructionsCompatibilityAnalysis Analysis(DT, DL, TTI, TLI);
12029+ SmallVector<BoUpSLP::ValueList> Operands = Analysis.buildOperands(S, VL);
12030+
12031+ InstructionsState OpS = getSameOpcode(Operands.front(), TLI);
12032+ if (!OpS.valid())
12033+ return InstructionCost::getInvalid();
12034+
12035+ if (OpS.isAltShuffle() || OpS.getOpcode() != Instruction::FMul)
12036+ return InstructionCost::getInvalid();
12037+ if (!CheckForContractable(Operands.front()))
12038+ return InstructionCost::getInvalid();
12039+ // Compare the costs.
12040+ InstructionCost FMulPlusFAddCost = 0;
12041+ InstructionCost FMACost = 0;
12042+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
12043+ FastMathFlags FMF;
12044+ FMF.set();
12045+ for (Value *V : VL) {
12046+ auto *I = dyn_cast<Instruction>(V);
12047+ if (!I)
12048+ continue;
12049+ if (!S.isCopyableElement(I))
12050+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12051+ FMF &= FPCI->getFastMathFlags();
12052+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
12053+ }
12054+ unsigned NumOps = 0;
12055+ for (auto [V, Op] : zip(VL, Operands.front())) {
12056+ if (S.isCopyableElement(V))
12057+ continue;
12058+ auto *I = dyn_cast<Instruction>(Op);
12059+ if (!I || !I->hasOneUse() || OpS.isCopyableElement(I)) {
12060+ if (auto *OpI = dyn_cast<Instruction>(V))
12061+ FMACost += TTI.getInstructionCost(OpI, CostKind);
12062+ if (I)
12063+ FMACost += TTI.getInstructionCost(I, CostKind);
12064+ continue;
12065+ }
12066+ ++NumOps;
12067+ if (auto *FPCI = dyn_cast<FPMathOperator>(I))
12068+ FMF &= FPCI->getFastMathFlags();
12069+ FMulPlusFAddCost += TTI.getInstructionCost(I, CostKind);
12070+ }
12071+ Type *Ty = VL.front()->getType();
12072+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, Ty, {Ty, Ty, Ty}, FMF);
12073+ FMACost += NumOps * TTI.getIntrinsicInstrCost(ICA, CostKind);
12074+ return FMACost < FMulPlusFAddCost ? FMACost : InstructionCost::getInvalid();
12075+ }
12076+
1199012077void BoUpSLP::transformNodes() {
1199112078 constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
1199212079 BaseGraphSize = VectorizableTree.size();
@@ -12355,6 +12442,25 @@ void BoUpSLP::transformNodes() {
1235512442 }
1235612443 break;
1235712444 }
12445+ case Instruction::FSub:
12446+ case Instruction::FAdd: {
12447+ // Check if possible to convert (a*b)+c to fma.
12448+ if (E.State != TreeEntry::Vectorize ||
12449+ !E.getOperations().isAddSubLikeOp())
12450+ break;
12451+ if (!canConvertToFMA(E.Scalars, E.getOperations(), *DT, *DL, *TTI, *TLI)
12452+ .isValid())
12453+ break;
12454+ // This node is a fmuladd node.
12455+ E.CombinedOp = TreeEntry::FMulAdd;
12456+ TreeEntry *FMulEntry = getOperandEntry(&E, 0);
12457+ if (FMulEntry->UserTreeIndex &&
12458+ FMulEntry->State == TreeEntry::Vectorize) {
12459+ // The FMul node is part of the combined fmuladd node.
12460+ FMulEntry->State = TreeEntry::CombinedVectorize;
12461+ }
12462+ break;
12463+ }
1235812464 default:
1235912465 break;
1236012466 }
@@ -13587,6 +13693,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1358713693 }
1358813694 return IntrinsicCost;
1358913695 };
13696+ auto GetFMulAddCost = [&, &TTI = *TTI](const InstructionsState &S,
13697+ Instruction *VI) {
13698+ InstructionCost Cost = canConvertToFMA(VI, S, *DT, *DL, TTI, *TLI);
13699+ return Cost;
13700+ };
1359013701 switch (ShuffleOrOp) {
1359113702 case Instruction::PHI: {
1359213703 // Count reused scalars.
@@ -13927,6 +14038,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1392714038 };
1392814039 return GetCostDiff(GetScalarCost, GetVectorCost);
1392914040 }
14041+ case TreeEntry::FMulAdd: {
14042+ auto GetScalarCost = [&](unsigned Idx) {
14043+ if (isa<PoisonValue>(UniqueValues[Idx]))
14044+ return InstructionCost(TTI::TCC_Free);
14045+ return GetFMulAddCost(E->getOperations(),
14046+ cast<Instruction>(UniqueValues[Idx]));
14047+ };
14048+ auto GetVectorCost = [&, &TTI = *TTI](InstructionCost CommonCost) {
14049+ FastMathFlags FMF;
14050+ FMF.set();
14051+ for (Value *V : E->Scalars) {
14052+ if (auto *FPCI = dyn_cast<FPMathOperator>(V)) {
14053+ FMF &= FPCI->getFastMathFlags();
14054+ if (auto *FPCIOp = dyn_cast<FPMathOperator>(FPCI->getOperand(0)))
14055+ FMF &= FPCIOp->getFastMathFlags();
14056+ }
14057+ }
14058+ IntrinsicCostAttributes ICA(Intrinsic::fmuladd, VecTy,
14059+ {VecTy, VecTy, VecTy}, FMF);
14060+ InstructionCost VecCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
14061+ return VecCost + CommonCost;
14062+ };
14063+ return GetCostDiff(GetScalarCost, GetVectorCost);
14064+ }
1393014065 case Instruction::FNeg:
1393114066 case Instruction::Add:
1393214067 case Instruction::FAdd:
@@ -13964,8 +14099,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
1396414099 }
1396514100 TTI::OperandValueInfo Op1Info = TTI::getOperandInfo(Op1);
1396614101 TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(Op2);
13967- return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind,
13968- Op1Info, Op2Info, Operands);
14102+ InstructionCost ScalarCost = TTI->getArithmeticInstrCost(
14103+ ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands);
14104+ if (auto *I = dyn_cast<Instruction>(UniqueValues[Idx]);
14105+ I && (ShuffleOrOp == Instruction::FAdd ||
14106+ ShuffleOrOp == Instruction::FSub)) {
14107+ InstructionCost IntrinsicCost = GetFMulAddCost(E->getOperations(), I);
14108+ if (IntrinsicCost.isValid())
14109+ ScalarCost = IntrinsicCost;
14110+ }
14111+ return ScalarCost;
1396914112 };
1397014113 auto GetVectorCost = [=](InstructionCost CommonCost) {
1397114114 if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) {
@@ -24205,7 +24348,7 @@ bool SLPVectorizerPass::vectorizeHorReduction(
2420524348 Stack.emplace(SelectRoot(), 0);
2420624349 SmallPtrSet<Value *, 8> VisitedInstrs;
2420724350 bool Res = false;
24208- auto && TryToReduce = [this, &R](Instruction *Inst) -> Value * {
24351+ auto TryToReduce = [this, &R, TTI = TTI ](Instruction *Inst) -> Value * {
2420924352 if (R.isAnalyzedReductionRoot(Inst))
2421024353 return nullptr;
2421124354 if (!isReductionCandidate(Inst))
@@ -24277,6 +24420,12 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
2427724420
2427824421 if (!isa<BinaryOperator, CmpInst>(I) || isa<VectorType>(I->getType()))
2427924422 return false;
24423+ // Skip potential FMA candidates.
24424+ if ((I->getOpcode() == Instruction::FAdd ||
24425+ I->getOpcode() == Instruction::FSub) &&
24426+ canConvertToFMA(I, getSameOpcode(I, *TLI), *DT, *DL, *TTI, *TLI)
24427+ .isValid())
24428+ return false;
2428024429
2428124430 Value *P = I->getParent();
2428224431
0 commit comments