Skip to content

Commit 8aa1220

Browse files
committed
[LV] Pass step to emitTransformedIndex (NFC).
Move the induction step creation out of emitTransformedIndex and into its callers. In some places (e.g. widenIntOrFpInduction) the step has already been created, so passing it in ensures the steps are kept in sync.
1 parent 9e69959 commit 8aa1220

File tree

1 file changed

+37
-44
lines changed

1 file changed

+37
-44
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2529,19 +2529,30 @@ static void buildScalarSteps(Value *ScalarIV, Value *Step,
25292529
}
25302530
}
25312531

2532+
// Generate code for the induction step. Note that induction steps are
2533+
// required to be loop-invariant.
2534+
static Value *CreateStepValue(const SCEV *Step, ScalarEvolution &SE,
2535+
Instruction *InsertBefore,
2536+
Loop *OrigLoop = nullptr) {
2537+
const DataLayout &DL = SE.getDataLayout();
2538+
assert((!OrigLoop || SE.isLoopInvariant(Step, OrigLoop)) &&
2539+
"Induction step should be loop invariant");
2540+
if (auto *E = dyn_cast<SCEVUnknown>(Step))
2541+
return E->getValue();
2542+
2543+
SCEVExpander Exp(SE, DL, "induction");
2544+
return Exp.expandCodeFor(Step, Step->getType(), InsertBefore);
2545+
}
2546+
25322547
/// Compute the transformed value of Index at offset StartValue using step
25332548
/// StepValue.
25342549
/// For integer induction, returns StartValue + Index * StepValue.
25352550
/// For pointer induction, returns StartValue[Index * StepValue].
25362551
/// FIXME: The newly created binary instructions should contain nsw/nuw
25372552
/// flags, which can be found from the original scalar operations.
2538-
static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
2539-
ScalarEvolution *SE, const DataLayout &DL,
2540-
const InductionDescriptor &ID, LoopInfo &LI,
2541-
BasicBlock *VectorHeader) {
2553+
static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, Value *Step,
2554+
const InductionDescriptor &ID) {
25422555

2543-
SCEVExpander Exp(*SE, DL, "induction");
2544-
auto Step = ID.getStep();
25452556
auto StartValue = ID.getStartValue();
25462557
assert(Index->getType()->getScalarType() == Step->getType() &&
25472558
"Index scalar type does not match StepValue type");
@@ -2580,39 +2591,21 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
25802591
return B.CreateMul(X, Y);
25812592
};
25822593

2583-
// Get a suitable insert point for SCEV expansion. For blocks in the vector
2584-
// loop, choose the end of the vector loop header (=VectorHeader), because
2585-
// the DomTree is not kept up-to-date for additional blocks generated in the
2586-
// vector loop. By using the header as insertion point, we guarantee that the
2587-
// expanded instructions dominate all their uses.
2588-
auto GetInsertPoint = [&B, &LI, VectorHeader]() {
2589-
BasicBlock *InsertBB = B.GetInsertPoint()->getParent();
2590-
if (InsertBB != VectorHeader &&
2591-
LI.getLoopFor(VectorHeader) == LI.getLoopFor(InsertBB))
2592-
return VectorHeader->getTerminator();
2593-
return &*B.GetInsertPoint();
2594-
};
2595-
25962594
switch (ID.getKind()) {
25972595
case InductionDescriptor::IK_IntInduction: {
25982596
assert(!isa<VectorType>(Index->getType()) &&
25992597
"Vector indices not supported for integer inductions yet");
26002598
assert(Index->getType() == StartValue->getType() &&
26012599
"Index type does not match StartValue type");
2602-
if (ID.getConstIntStepValue() && ID.getConstIntStepValue()->isMinusOne())
2600+
if (isa<ConstantInt>(Step) && cast<ConstantInt>(Step)->isMinusOne())
26032601
return B.CreateSub(StartValue, Index);
2604-
auto *Offset = CreateMul(
2605-
Index, Exp.expandCodeFor(Step, Index->getType(), GetInsertPoint()));
2602+
auto *Offset = CreateMul(Index, Step);
26062603
return CreateAdd(StartValue, Offset);
26072604
}
26082605
case InductionDescriptor::IK_PtrInduction: {
2609-
assert(isa<SCEVConstant>(Step) &&
2606+
assert(isa<Constant>(Step) &&
26102607
"Expected constant step for pointer induction");
2611-
return B.CreateGEP(
2612-
ID.getElementType(), StartValue,
2613-
CreateMul(Index,
2614-
Exp.expandCodeFor(Step, Index->getType()->getScalarType(),
2615-
GetInsertPoint())));
2608+
return B.CreateGEP(ID.getElementType(), StartValue, CreateMul(Index, Step));
26162609
}
26172610
case InductionDescriptor::IK_FpInduction: {
26182611
assert(!isa<VectorType>(Index->getType()) &&
@@ -2624,8 +2617,7 @@ static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index,
26242617
InductionBinOp->getOpcode() == Instruction::FSub) &&
26252618
"Original bin op should be defined for FP induction");
26262619

2627-
Value *StepValue = cast<SCEVUnknown>(Step)->getValue();
2628-
Value *MulExp = B.CreateFMul(StepValue, Index);
2620+
Value *MulExp = B.CreateFMul(Step, Index);
26292621
return B.CreateBinOp(InductionBinOp->getOpcode(), StartValue, MulExp,
26302622
"induction");
26312623
}
@@ -2676,8 +2668,7 @@ void InnerLoopVectorizer::widenIntOrFpInduction(
26762668
NeededType->isIntegerTy()
26772669
? Builder.CreateSExtOrTrunc(ScalarIV, NeededType)
26782670
: Builder.CreateCast(Instruction::SIToFP, ScalarIV, NeededType);
2679-
ScalarIV = emitTransformedIndex(Builder, ScalarIV, PSE.getSE(), DL, ID,
2680-
*State.LI, State.CFG.PrevBB);
2671+
ScalarIV = emitTransformedIndex(Builder, ScalarIV, Step, ID);
26812672
ScalarIV->setName("offset.idx");
26822673
}
26832674
if (Trunc) {
@@ -3410,20 +3401,21 @@ void InnerLoopVectorizer::createInductionResumeValues(
34103401
Instruction::CastOps CastOp =
34113402
CastInst::getCastOpcode(VectorTripCount, true, StepType, true);
34123403
Value *CRD = B.CreateCast(CastOp, VectorTripCount, StepType, "cast.crd");
3413-
const DataLayout &DL = LoopScalarBody->getModule()->getDataLayout();
3414-
EndValue = emitTransformedIndex(B, CRD, PSE.getSE(), DL, II, *LI,
3415-
LoopVectorBody);
3404+
Value *Step =
3405+
CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint());
3406+
EndValue = emitTransformedIndex(B, CRD, Step, II);
34163407
EndValue->setName("ind.end");
34173408

34183409
// Compute the end value for the additional bypass (if applicable).
34193410
if (AdditionalBypass.first) {
34203411
B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt()));
34213412
CastOp = CastInst::getCastOpcode(AdditionalBypass.second, true,
34223413
StepType, true);
3414+
Value *Step =
3415+
CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint());
34233416
CRD =
34243417
B.CreateCast(CastOp, AdditionalBypass.second, StepType, "cast.crd");
3425-
EndValueFromAdditionalBypass = emitTransformedIndex(
3426-
B, CRD, PSE.getSE(), DL, II, *LI, LoopVectorBody);
3418+
EndValueFromAdditionalBypass = emitTransformedIndex(B, CRD, Step, II);
34273419
EndValueFromAdditionalBypass->setName("ind.end");
34283420
}
34293421
}
@@ -3597,8 +3589,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
35973589
for (User *U : OrigPhi->users()) {
35983590
auto *UI = cast<Instruction>(U);
35993591
if (!OrigLoop->contains(UI)) {
3600-
const DataLayout &DL =
3601-
OrigLoop->getHeader()->getModule()->getDataLayout();
36023592
assert(isa<PHINode>(UI) && "Expected LCSSA form");
36033593

36043594
IRBuilder<> B(MiddleBlock->getTerminator());
@@ -3615,8 +3605,10 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
36153605
II.getStep()->getType())
36163606
: B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType());
36173607
CMO->setName("cast.cmo");
3618-
Value *Escape = emitTransformedIndex(B, CMO, PSE.getSE(), DL, II, *LI,
3619-
LoopVectorBody);
3608+
3609+
Value *Step = CreateStepValue(II.getStep(), *PSE.getSE(),
3610+
LoopVectorBody->getTerminator());
3611+
Value *Escape = emitTransformedIndex(B, CMO, Step, II);
36203612
Escape->setName("ind.escape");
36213613
MissingVals[UI] = Escape;
36223614
}
@@ -4504,9 +4496,10 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN,
45044496
Value *Idx = Builder.CreateAdd(
45054497
PartStart, ConstantInt::get(PtrInd->getType(), Lane));
45064498
Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx);
4507-
Value *SclrGep =
4508-
emitTransformedIndex(Builder, GlobalIdx, PSE.getSE(), DL, II,
4509-
*State.LI, State.CFG.PrevBB);
4499+
4500+
Value *Step = CreateStepValue(II.getStep(), *PSE.getSE(),
4501+
State.CFG.PrevBB->getTerminator());
4502+
Value *SclrGep = emitTransformedIndex(Builder, GlobalIdx, Step, II);
45104503
SclrGep->setName("next.gep");
45114504
State.set(PhiR, SclrGep, VPIteration(Part, Lane));
45124505
}

0 commit comments

Comments
 (0)