@@ -222,6 +222,10 @@ static cl::opt<unsigned> RangeIterThreshold(
222222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223223 cl::init(32));
224224
225+ static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
226+ "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
227+ cl::desc("Maximum depth for recrusive loop guard collection"), cl::init(1));
228+
225229static cl::opt<bool>
226230ClassifyExpressions("scalar-evolution-classify-expressions",
227231 cl::Hidden, cl::init(true),
@@ -10608,7 +10612,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
1060810612 if (const Loop *L = LI.getLoopFor(BB))
1060910613 return {L->getLoopPredecessor(), L->getHeader()};
1061010614
10611- return {nullptr, nullptr };
10615+ return {nullptr, BB };
1061210616}
1061310617
1061410618/// SCEV structural equivalence is usually sufficient for testing whether two
@@ -15089,7 +15093,81 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
1508915093
1509015094ScalarEvolution::LoopGuards
1509115095ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15096+ BasicBlock *Header = L->getHeader();
15097+ BasicBlock *Pred = L->getLoopPredecessor();
1509215098 LoopGuards Guards(SE);
15099+ SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15100+ collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15101+ return Guards;
15102+ }
15103+
15104+ void ScalarEvolution::LoopGuards::collectFromPHI(
15105+ ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15106+ const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15107+ SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15108+ unsigned Depth) {
15109+ if (!SE.isSCEVable(Phi.getType()))
15110+ return;
15111+
15112+ using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15113+ auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15114+ const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15115+ if (!VisitedBlocks.insert(InBlock).second)
15116+ return {nullptr, scCouldNotCompute};
15117+ auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15118+ if (Inserted)
15119+ collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15120+ Depth + 1);
15121+ auto &RewriteMap = G->second.RewriteMap;
15122+ if (RewriteMap.empty())
15123+ return {nullptr, scCouldNotCompute};
15124+ auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15125+ if (S == RewriteMap.end())
15126+ return {nullptr, scCouldNotCompute};
15127+ auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15128+ if (!SM)
15129+ return {nullptr, scCouldNotCompute};
15130+ if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15131+ return {C0, SM->getSCEVType()};
15132+ return {nullptr, scCouldNotCompute};
15133+ };
15134+ auto MergeMinMaxConst = [](MinMaxPattern P1,
15135+ MinMaxPattern P2) -> MinMaxPattern {
15136+ auto [C1, T1] = P1;
15137+ auto [C2, T2] = P2;
15138+ if (!C1 || !C2 || T1 != T2)
15139+ return {nullptr, scCouldNotCompute};
15140+ switch (T1) {
15141+ case scUMaxExpr:
15142+ return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15143+ case scSMaxExpr:
15144+ return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15145+ case scUMinExpr:
15146+ return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15147+ case scSMinExpr:
15148+ return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15149+ default:
15150+ llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15151+ }
15152+ };
15153+ auto P = GetMinMaxConst(0);
15154+ for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15155+ if (!P.first)
15156+ break;
15157+ P = MergeMinMaxConst(P, GetMinMaxConst(In));
15158+ }
15159+ if (P.first) {
15160+ const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15161+ SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15162+ const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15163+ Guards.RewriteMap.insert({LHS, RHS});
15164+ }
15165+ }
15166+
15167+ void ScalarEvolution::LoopGuards::collectFromBlock(
15168+ ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15169+ const BasicBlock *Block, const BasicBlock *Pred,
15170+ SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
1509315171 SmallVector<const SCEV *> ExprsToRewrite;
1509415172 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1509515173 const SCEV *RHS,
@@ -15428,14 +15506,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1542815506 }
1542915507 };
1543015508
15431- BasicBlock *Header = L->getHeader();
1543215509 SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
1543315510 // First, collect information from assumptions dominating the loop.
1543415511 for (auto &AssumeVH : SE.AC.assumptions()) {
1543515512 if (!AssumeVH)
1543615513 continue;
1543715514 auto *AssumeI = cast<CallInst>(AssumeVH);
15438- if (!SE.DT.dominates(AssumeI, Header ))
15515+ if (!SE.DT.dominates(AssumeI, Block ))
1543915516 continue;
1544015517 Terms.emplace_back(AssumeI->getOperand(0), true);
1544115518 }
@@ -15446,27 +15523,42 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1544615523 if (GuardDecl)
1544715524 for (const auto *GU : GuardDecl->users())
1544815525 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15449- if (Guard->getFunction() == Header ->getParent() &&
15450- SE.DT.dominates(Guard, Header ))
15526+ if (Guard->getFunction() == Block ->getParent() &&
15527+ SE.DT.dominates(Guard, Block ))
1545115528 Terms.emplace_back(Guard->getArgOperand(0), true);
1545215529
1545315530 // Third, collect conditions from dominating branches. Starting at the loop
1545415531 // predecessor, climb up the predecessor chain, as long as there are
1545515532 // predecessors that can be found that have unique successors leading to the
1545615533 // original header.
1545715534 // TODO: share this logic with isLoopEntryGuardedByCond.
15458- for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15459- L->getLoopPredecessor(), Header);
15460- Pair.first;
15535+ std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15536+ for (; Pair.first;
1546115537 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15462-
15538+ VisitedBlocks.insert(Pair.second);
1546315539 const BranchInst *LoopEntryPredicate =
1546415540 dyn_cast<BranchInst>(Pair.first->getTerminator());
1546515541 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
1546615542 continue;
1546715543
1546815544 Terms.emplace_back(LoopEntryPredicate->getCondition(),
1546915545 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15546+
15547+ // If we are recursively collecting guards stop after 2
15548+ // predecessors to limit compile-time impact for now.
15549+ if (Depth > 0 && Terms.size() == 2)
15550+ break;
15551+ }
15552+ // Finally, if we stopped climbing the predecessor chain because
15553+ // there wasn't a unique one to continue, try to collect conditions
15554+ // for PHINodes by recursively following all of their incoming
15555+ // blocks and try to merge the found conditions to build a new one
15556+ // for the Phi.
15557+ if (Pair.second->hasNPredecessorsOrMore(2) &&
15558+ Depth < MaxLoopGuardCollectionDepth) {
15559+ SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15560+ for (auto &Phi : Pair.second->phis())
15561+ collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
1547015562 }
1547115563
1547215564 // Now apply the information from the collected conditions to
@@ -15523,7 +15615,6 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
1552315615 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
1552415616 }
1552515617 }
15526- return Guards;
1552715618}
1552815619
1552915620const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
0 commit comments