@@ -379,6 +379,10 @@ class LoopInterchangeLegality {
379
379
return InnerLoopInductions;
380
380
}
381
381
382
+ ArrayRef<Instruction *> getHasNoWrapReductions () const {
383
+ return HasNoWrapReductions;
384
+ }
385
+
382
386
private:
383
387
bool tightlyNested (Loop *Outer, Loop *Inner);
384
388
bool containsUnsafeInstructions (BasicBlock *BB);
@@ -405,6 +409,11 @@ class LoopInterchangeLegality {
405
409
406
410
// / Set of inner loop induction PHIs
407
411
SmallVector<PHINode *, 8 > InnerLoopInductions;
412
+
413
+ // / Hold instructions that have nuw/nsw flags and involved in reductions,
414
+ // / like integer addition/multiplication. Those flags must be dropped when
415
+ // / interchanging the loops.
416
+ SmallVector<Instruction *, 4 > HasNoWrapReductions;
408
417
};
409
418
410
419
// / Manages information utilized by the profitability check for cache. The main
@@ -473,7 +482,7 @@ class LoopInterchangeTransform {
473
482
: OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), LIL(LIL) {}
474
483
475
484
// / Interchange OuterLoop and InnerLoop.
476
- bool transform ();
485
+ bool transform (ArrayRef<Instruction *> DropNoWrapInsts );
477
486
void restructureLoops (Loop *NewInner, Loop *NewOuter,
478
487
BasicBlock *OrigInnerPreHeader,
479
488
BasicBlock *OrigOuterPreHeader);
@@ -613,7 +622,7 @@ struct LoopInterchange {
613
622
});
614
623
615
624
LoopInterchangeTransform LIT (OuterLoop, InnerLoop, SE, LI, DT, LIL);
616
- LIT.transform ();
625
+ LIT.transform (LIL. getHasNoWrapReductions () );
617
626
LLVM_DEBUG (dbgs () << " Loops interchanged.\n " );
618
627
LoopsInterchanged++;
619
628
@@ -798,7 +807,9 @@ static Value *followLCSSA(Value *SV) {
798
807
}
799
808
800
809
// Check V's users to see if it is involved in a reduction in L.
801
- static PHINode *findInnerReductionPhi (Loop *L, Value *V) {
810
+ static PHINode *
811
+ findInnerReductionPhi (Loop *L, Value *V,
812
+ SmallVectorImpl<Instruction *> &HasNoWrapInsts) {
802
813
// Reduction variables cannot be constants.
803
814
if (isa<Constant>(V))
804
815
return nullptr ;
@@ -812,7 +823,65 @@ static PHINode *findInnerReductionPhi(Loop *L, Value *V) {
812
823
// Detect floating point reduction only when it can be reordered.
813
824
if (RD.getExactFPMathInst () != nullptr )
814
825
return nullptr ;
815
- return PHI;
826
+
827
+ RecurKind RK = RD.getRecurrenceKind ();
828
+ switch (RK) {
829
+ case RecurKind::Or:
830
+ case RecurKind::And:
831
+ case RecurKind::Xor:
832
+ case RecurKind::SMin:
833
+ case RecurKind::SMax:
834
+ case RecurKind::UMin:
835
+ case RecurKind::UMax:
836
+ case RecurKind::FAdd:
837
+ case RecurKind::FMul:
838
+ case RecurKind::FMin:
839
+ case RecurKind::FMax:
840
+ case RecurKind::FMinimum:
841
+ case RecurKind::FMaximum:
842
+ case RecurKind::FMinimumNum:
843
+ case RecurKind::FMaximumNum:
844
+ case RecurKind::FMulAdd:
845
+ case RecurKind::AnyOf:
846
+ return PHI;
847
+
848
+ // Change the order of integer addition/multiplication may change the
849
+ // semantics. Consider the following case:
850
+ //
851
+ // int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
852
+ // int sum = 0;
853
+ // for (int i = 0; i < 2; i++)
854
+ // for (int j = 0; j < 2; j++)
855
+ // sum += A[j][i];
856
+ //
857
+ // If the above loops are exchanged, the addition will cause an
858
+ // overflow. To prevent this, we must drop the nuw/nsw flags from the
859
+ // addition/multiplication instructions when we actually exchanges the
860
+ // loops.
861
+ case RecurKind::Add:
862
+ case RecurKind::Mul: {
863
+ unsigned OpCode = RecurrenceDescriptor::getOpcode (RK);
864
+ SmallVector<Instruction *, 4 > Ops = RD.getReductionOpChain (PHI, L);
865
+
866
+ // Bail out when we fail to collect reduction instructions chain.
867
+ if (Ops.empty ())
868
+ return nullptr ;
869
+
870
+ for (Instruction *I : Ops) {
871
+ assert (I->getOpcode () == OpCode &&
872
+ " Expected the instruction to be the reduction operation" );
873
+
874
+ // If the instruction has nuw/nsw flags, we must drop them when the
875
+ // transformation is actually performed.
876
+ if (I->hasNoSignedWrap () || I->hasNoUnsignedWrap ())
877
+ HasNoWrapInsts.push_back (I);
878
+ }
879
+ return PHI;
880
+ }
881
+
882
+ default :
883
+ return nullptr ;
884
+ }
816
885
}
817
886
return nullptr ;
818
887
}
@@ -844,7 +913,8 @@ bool LoopInterchangeLegality::findInductionAndReductions(
844
913
// Check if we have a PHI node in the outer loop that has a reduction
845
914
// result from the inner loop as an incoming value.
846
915
Value *V = followLCSSA (PHI.getIncomingValueForBlock (L->getLoopLatch ()));
847
- PHINode *InnerRedPhi = findInnerReductionPhi (InnerLoop, V);
916
+ PHINode *InnerRedPhi =
917
+ findInnerReductionPhi (InnerLoop, V, HasNoWrapReductions);
848
918
if (!InnerRedPhi ||
849
919
!llvm::is_contained (InnerRedPhi->incoming_values (), &PHI)) {
850
920
LLVM_DEBUG (
@@ -1430,7 +1500,8 @@ void LoopInterchangeTransform::restructureLoops(
1430
1500
SE->forgetLoop (NewOuter);
1431
1501
}
1432
1502
1433
- bool LoopInterchangeTransform::transform () {
1503
+ bool LoopInterchangeTransform::transform (
1504
+ ArrayRef<Instruction *> DropNoWrapInsts) {
1434
1505
bool Transformed = false ;
1435
1506
1436
1507
if (InnerLoop->getSubLoops ().empty ()) {
@@ -1531,6 +1602,13 @@ bool LoopInterchangeTransform::transform() {
1531
1602
return false ;
1532
1603
}
1533
1604
1605
+ // Finally, drop the nsw/nuw flags from the instructions for reduction
1606
+ // calculations.
1607
+ for (Instruction *Reduction : DropNoWrapInsts) {
1608
+ Reduction->setHasNoSignedWrap (false );
1609
+ Reduction->setHasNoUnsignedWrap (false );
1610
+ }
1611
+
1534
1612
return true ;
1535
1613
}
1536
1614
0 commit comments