@@ -8268,6 +8268,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
82688268 return Recipe;
82698269}
82708270
8271+ // / Find all possible partial reductions in the loop and track all of those that
8272+ // / are valid so recipes can be formed later.
8273+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8274+ // Find all possible partial reductions.
8275+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8276+ PartialReductionChains;
8277+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8278+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8279+ getScaledReduction (Phi, RdxDesc, Range))
8280+ PartialReductionChains.push_back (*Pair);
8281+
8282+ // A partial reduction is invalid if any of its extends are used by
8283+ // something that isn't another partial reduction. This is because the
8284+ // extends are intended to be lowered along with the reduction itself.
8285+
8286+ // Build up a set of partial reduction bin ops for efficient use checking.
8287+ SmallSet<User *, 4 > PartialReductionBinOps;
8288+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8289+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8290+
8291+ auto ExtendIsOnlyUsedByPartialReductions =
8292+ [&PartialReductionBinOps](Instruction *Extend) {
8293+ return all_of (Extend->users (), [&](const User *U) {
8294+ return PartialReductionBinOps.contains (U);
8295+ });
8296+ };
8297+
8298+ // Check if each use of a chain's two extends is a partial reduction
8299+ // and only add those that don't have non-partial reduction users.
8300+ for (auto Pair : PartialReductionChains) {
8301+ PartialReductionChain Chain = Pair.first ;
8302+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8303+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8304+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8305+ }
8306+ }
8307+
8308+ std::optional<std::pair<PartialReductionChain, unsigned >>
8309+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8310+ const RecurrenceDescriptor &Rdx,
8311+ VFRange &Range) {
8312+ // TODO: Allow scaling reductions when predicating. The select at
8313+ // the end of the loop chooses between the phi value and most recent
8314+ // reduction result, both of which have different VFs to the active lane
8315+ // mask when scaling.
8316+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8317+ return std::nullopt ;
8318+
8319+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8320+ if (!Update)
8321+ return std::nullopt ;
8322+
8323+ Value *Op = Update->getOperand (0 );
8324+ Value *PhiOp = Update->getOperand (1 );
8325+ if (Op == PHI) {
8326+ Op = Update->getOperand (1 );
8327+ PhiOp = Update->getOperand (0 );
8328+ }
8329+ if (PhiOp != PHI)
8330+ return std::nullopt ;
8331+
8332+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8333+ if (!BinOp || !BinOp->hasOneUse ())
8334+ return std::nullopt ;
8335+
8336+ using namespace llvm ::PatternMatch;
8337+ Value *A, *B;
8338+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8339+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8340+ return std::nullopt ;
8341+
8342+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8343+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8344+
8345+ TTI::PartialReductionExtendKind OpAExtend =
8346+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8347+ TTI::PartialReductionExtendKind OpBExtend =
8348+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8349+
8350+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8351+
8352+ unsigned TargetScaleFactor =
8353+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8354+ A->getType ()->getPrimitiveSizeInBits ());
8355+
8356+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8357+ [&](ElementCount VF) {
8358+ InstructionCost Cost = TTI->getPartialReductionCost (
8359+ Update->getOpcode (), A->getType (), B->getType (), PHI->getType (),
8360+ VF, OpAExtend, OpBExtend,
8361+ std::make_optional (BinOp->getOpcode ()));
8362+ return Cost.isValid ();
8363+ },
8364+ Range))
8365+ return std::make_pair (Chain, TargetScaleFactor);
8366+
8367+ return std::nullopt ;
8368+ }
8369+
82718370VPRecipeBase *
82728371VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
82738372 ArrayRef<VPValue *> Operands,
@@ -8292,9 +8391,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
82928391 Legal->getReductionVars ().find (Phi)->second ;
82938392 assert (RdxDesc.getRecurrenceStartValue () ==
82948393 Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8295- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8296- CM.isInLoopReduction (Phi),
8297- CM.useOrderedReductions (RdxDesc));
8394+
8395+ // If the PHI is used by a partial reduction, set the scale factor.
8396+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8397+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8398+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8399+ PhiRecipe = new VPReductionPHIRecipe (
8400+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8401+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
82988402 } else {
82998403 // TODO: Currently fixed-order recurrences are modeled as chains of
83008404 // first-order recurrences. If there are no users of the intermediate
@@ -8322,6 +8426,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
83228426 if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
83238427 return tryToWidenMemory (Instr, Operands, Range);
83248428
8429+ if (getScaledReductionForInstr (Instr))
8430+ return tryToCreatePartialReduction (Instr, Operands);
8431+
83258432 if (!shouldWiden (Instr, Range))
83268433 return nullptr ;
83278434
@@ -8342,6 +8449,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
83428449 return tryToWiden (Instr, Operands, VPBB);
83438450}
83448451
8452+ VPRecipeBase *
8453+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
8454+ ArrayRef<VPValue *> Operands) {
8455+ assert (Operands.size () == 2 &&
8456+ " Unexpected number of operands for partial reduction" );
8457+
8458+ VPValue *BinOp = Operands[0 ];
8459+ VPValue *Phi = Operands[1 ];
8460+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8461+ std::swap (BinOp, Phi);
8462+
8463+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8464+ Reduction);
8465+ }
8466+
83458467void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
83468468 ElementCount MaxVF) {
83478469 assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -8514,7 +8636,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
85148636 bool HasNUW = Style == TailFoldingStyle::None;
85158637 addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
85168638
8517- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
8639+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
8640+ Builder);
85188641
85198642 // ---------------------------------------------------------------------------
85208643 // Pre-construction: record ingredients whose recipes we'll need to further
@@ -8560,6 +8683,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
85608683 bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
85618684 return Legal->blockNeedsPredication (BB) || NeedsBlends;
85628685 });
8686+
8687+ RecipeBuilder.collectScaledReductions (Range);
8688+
85638689 for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
85648690 // Relevant instructions from basic block BB will be grouped into VPRecipe
85658691 // ingredients and fill a new VPBasicBlock.
@@ -8770,7 +8896,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
87708896 bool HasNUW = true ;
87718897 addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW,
87728898 DebugLoc ());
8773- assert (verifyVPlanIsValid (*Plan) && " VPlan is invalid" );
8899+ assert (verifyVPlanIsValid (*Plan) && " VPlan is invalid" );
87748900 return Plan;
87758901}
87768902
0 commit comments