164 changes: 139 additions & 25 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -123,7 +123,9 @@ class VectorCombine {
bool foldExtractedCmps(Instruction &I);
bool foldBinopOfReductions(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
bool scalarizeLoad(Instruction &I);
bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
bool scalarizeExtExtract(Instruction &I);
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
@@ -1664,8 +1666,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
/// Try to scalarize vector loads feeding extractelement or bitcast
/// instructions.
bool VectorCombine::scalarizeLoad(Instruction &I) {
Value *Ptr;
if (!match(&I, m_Load(m_Value(Ptr))))
return false;
@@ -1675,35 +1678,30 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
return false;

InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);
InstructionCost ScalarizedCost = 0;

bool AllExtracts = true;
bool AllBitcasts = true;
Instruction *LastCheckedInst = LI;
unsigned NumInstChecked = 0;
DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
auto FailureGuard = make_scope_exit([&]() {
// If the transform is aborted, discard the ScalarizationResults.
for (auto &Pair : NeedFreeze)
Pair.second.discard();
});

// Check if all users of the load are extracts with no memory modifications
// between the load and the extract. Compute the cost of both the original
// code and the scalarized version.
// Check which kind of users the load has (they must all be extractelements
// or all be bitcasts) and ensure no memory modifications occur between the
// load and its users.
for (User *U : LI->users()) {
auto *UI = dyn_cast<ExtractElementInst>(U);
auto *UI = dyn_cast<Instruction>(U);
if (!UI || UI->getParent() != LI->getParent())
return false;

// If any extract is waiting to be erased, then bail out as this will
// If any user is waiting to be erased, then bail out as this will
// distort the cost calculation and possibly lead to infinite loops.
if (UI->use_empty())
return false;

// Check if any instruction between the load and the extract may modify
// memory.
if (!isa<ExtractElementInst>(UI))
AllExtracts = false;
if (!isa<BitCastInst>(UI))
AllBitcasts = false;

// Check if any instruction between the load and the user may modify memory.
if (LastCheckedInst->comesBefore(UI)) {
for (Instruction &I :
make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1715,6 +1713,33 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
}
LastCheckedInst = UI;
}
}

if (AllExtracts)
return scalarizeLoadExtract(LI, VecTy, Ptr);
if (AllBitcasts)
return scalarizeLoadBitcast(LI, VecTy, Ptr);
return false;
}
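
As a quick illustration of the dispatch above, a minimal IR sketch (hypothetical, not taken from this patch's tests): the first load qualifies for the extract path, the second for the bitcast path, and the third mixes user kinds, so both AllExtracts and AllBitcasts end up false and the combine bails out.

```llvm
define i32 @dispatch_shapes(ptr %p, ptr %q, ptr %r) {
  ; all users are extractelements -> scalarizeLoadExtract path
  %v1 = load <4 x i32>, ptr %p
  %a = extractelement <4 x i32> %v1, i64 0
  %b = extractelement <4 x i32> %v1, i64 3
  %ab = add i32 %a, %b

  ; all users are bitcasts -> scalarizeLoadBitcast path
  %v2 = load <4 x i8>, ptr %q
  %c = bitcast <4 x i8> %v2 to i32

  ; mixed extract/bitcast users -> neither flag holds, no transform
  %v3 = load <4 x i8>, ptr %r
  %d = extractelement <4 x i8> %v3, i64 1
  %e = bitcast <4 x i8> %v3 to i32
  %f = zext i8 %d to i32
  %g = add i32 %e, %f
  %h = add i32 %ab, %c
  %res = add i32 %g, %h
  ret i32 %res
}
```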

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
Value *Ptr) {

DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
auto FailureGuard = make_scope_exit([&]() {
// If the transform is aborted, discard the ScalarizationResults.
for (auto &Pair : NeedFreeze)
Pair.second.discard();
});

InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);
InstructionCost ScalarizedCost = 0;

for (User *U : LI->users()) {
auto *UI = cast<ExtractElementInst>(U);

auto ScalarIdx =
canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1735,7 +1760,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
}

LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
<< "\n LoadExtractCost: " << OriginalCost
<< " vs ScalarizedCost: " << ScalarizedCost << "\n");

@@ -1773,6 +1798,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}
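
For reference, a hedged before/after sketch of the extract path (names are illustrative; the rewrite only fires when ScalarizedCost beats OriginalCost, and variable indices may additionally require a freeze via the NeedFreeze map):

```llvm
; before: vector load whose users all extract constant lanes
%v = load <4 x i32>, ptr %p, align 16
%e0 = extractelement <4 x i32> %v, i64 0
%e2 = extractelement <4 x i32> %v, i64 2

; after: one narrow load per lane, addressed through an inbounds GEP,
; with the alignment re-derived from the original load's alignment
%p0 = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 0
%s0 = load i32, ptr %p0, align 16
%p2 = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 2
%s2 = load i32, ptr %p2, align 8
```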

/// Try to scalarize vector loads feeding bitcast instructions.
bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
Value *Ptr) {
InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);

Type *TargetScalarType = nullptr;
unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);

for (User *U : LI->users()) {
auto *BC = cast<BitCastInst>(U);

Type *DestTy = BC->getDestTy();
if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
return false;

unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
if (DestBitWidth != VecBitWidth)
return false;

// All bitcasts must target the same scalar type.
if (!TargetScalarType)
TargetScalarType = DestTy;
else if (TargetScalarType != DestTy)
return false;

OriginalCost +=
TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
TTI.getCastContextHint(BC), CostKind, BC);
}

if (!TargetScalarType)
return false;

assert(!LI->user_empty() && "Unexpected load without bitcast users");
InstructionCost ScalarizedCost =
TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);

LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
<< "\n OriginalCost: " << OriginalCost
<< " vs ScalarizedCost: " << ScalarizedCost << "\n");

if (ScalarizedCost >= OriginalCost)
return false;

// Ensure we add the load back to the worklist BEFORE its users so they can
// be erased in the correct order.
Worklist.push(LI);

Builder.SetInsertPoint(LI);
auto *ScalarLoad =
Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
ScalarLoad->setAlignment(LI->getAlign());
ScalarLoad->copyMetadata(*LI);

// Replace all bitcast users with the scalar load.
for (User *U : LI->users()) {
auto *BC = cast<BitCastInst>(U);
replaceValue(*BC, *ScalarLoad);
}

return true;
}
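
The effect of the bitcast path on a toy example (a sketch, assuming the target's cost model prefers the scalar load):

```llvm
; before: the vector load's only users reinterpret all of its bits
%v = load <4 x i8>, ptr %p, align 4
%s = bitcast <4 x i8> %v to i32

; after: load the scalar type directly, keeping alignment and metadata
%s = load i32, ptr %p, align 4
```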

bool VectorCombine::scalarizeExtExtract(Instruction &I) {
auto *Ext = dyn_cast<ZExtInst>(&I);
if (!Ext)
@@ -1822,8 +1913,31 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {

Value *ScalarV = Ext->getOperand(0);
if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
&DT))
ScalarV = Builder.CreateFreeze(ScalarV);
&DT)) {
// Check whether all lanes are extracted, all extracts trigger UB
// on poison, and the last extract (and hence all previous ones)
// is guaranteed to execute if Ext executes. If so, we do not
// need to insert a freeze.
SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
bool AllExtractsTriggerUB = true;
ExtractElementInst *LastExtract = nullptr;
BasicBlock *ExtBB = Ext->getParent();
for (User *U : Ext->users()) {
auto *Extract = cast<ExtractElementInst>(U);
if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Extract)) {
AllExtractsTriggerUB = false;
break;
}
ExtractedLanes.insert(cast<ConstantInt>(Extract->getIndexOperand()));
if (!LastExtract || LastExtract->comesBefore(Extract))
LastExtract = Extract;
}
if (ExtractedLanes.size() != DstTy->getNumElements() ||
!AllExtractsTriggerUB ||
!isGuaranteedToTransferExecutionToSuccessor(Ext->getIterator(),
LastExtract->getIterator()))
ScalarV = Builder.CreateFreeze(ScalarV);
}
ScalarV = Builder.CreateBitCast(
ScalarV,
IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
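
A hypothetical example of the case the new check accepts (it leans on the LangRef rule that a poison divisor of udiv is immediate UB, which is the kind of use programUndefinedIfPoison detects): all four lanes are extracted in the zext's block and each extract would make the program undefined if it were poison, so %v can be treated as non-poison and no freeze is emitted before the bitcast.

```llvm
define i32 @no_freeze_needed(<4 x i8> %v, i32 %x) {
  %ext = zext <4 x i8> %v to <4 x i32>
  %l0 = extractelement <4 x i32> %ext, i64 0
  %l1 = extractelement <4 x i32> %ext, i64 1
  %l2 = extractelement <4 x i32> %ext, i64 2
  %l3 = extractelement <4 x i32> %ext, i64 3
  ; every lane ends up as a udiv divisor: a poison lane would be
  ; immediate UB, so the vector must be non-poison if we get here
  %d0 = udiv i32 %x, %l0
  %d1 = udiv i32 %d0, %l1
  %d2 = udiv i32 %d1, %l2
  %d3 = udiv i32 %d2, %l3
  ret i32 %d3
}
```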
@@ -3734,7 +3848,7 @@ bool VectorCombine::run() {
// TODO: Identify and allow other scalable transforms
if (IsVectorType) {
MadeChange |= scalarizeOpOrCmp(I);
MadeChange |= scalarizeLoadExtract(I);
MadeChange |= scalarizeLoad(I);
MadeChange |= scalarizeExtExtract(I);
MadeChange |= scalarizeVPIntrinsic(I);
MadeChange |= foldInterleaveIntrinsics(I);
@@ -0,0 +1,32 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
; RUN: opt -O3 -mtriple=arm64-apple-darwinos -S %s | FileCheck %s

define noundef i32 @load_ext_extract(ptr %src) {
; CHECK-LABEL: define noundef range(i32 0, 1021) i32 @load_ext_extract(
; CHECK-SAME: ptr readonly captures(none) [[SRC:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP14:%.*]] = load i32, ptr [[SRC]], align 4
; CHECK-NEXT: [[TMP15:%.*]] = lshr i32 [[TMP14]], 24
; CHECK-NEXT: [[TMP16:%.*]] = lshr i32 [[TMP14]], 16
; CHECK-NEXT: [[TMP17:%.*]] = and i32 [[TMP16]], 255
; CHECK-NEXT: [[TMP18:%.*]] = lshr i32 [[TMP14]], 8
; CHECK-NEXT: [[TMP19:%.*]] = and i32 [[TMP18]], 255
; CHECK-NEXT: [[TMP20:%.*]] = and i32 [[TMP14]], 255
; CHECK-NEXT: [[ADD1:%.*]] = add nuw nsw i32 [[TMP20]], [[TMP19]]
; CHECK-NEXT: [[ADD2:%.*]] = add nuw nsw i32 [[ADD1]], [[TMP17]]
; CHECK-NEXT: [[ADD3:%.*]] = add nuw nsw i32 [[ADD2]], [[TMP15]]
; CHECK-NEXT: ret i32 [[ADD3]]
;
entry:
%x = load <4 x i8>, ptr %src, align 4
%ext = zext nneg <4 x i8> %x to <4 x i32>
%ext.0 = extractelement <4 x i32> %ext, i64 0
%ext.1 = extractelement <4 x i32> %ext, i64 1
%ext.2 = extractelement <4 x i32> %ext, i64 2
%ext.3 = extractelement <4 x i32> %ext, i64 3

%add1 = add i32 %ext.0, %ext.1
%add2 = add i32 %add1, %ext.2
%add3 = add i32 %add2, %ext.3
ret i32 %add3
}