Skip to content

Commit 03f1c8f

Browse files
committed
[AArch64] Optimize extending loads of small vectors (llvm#163064)
Reduces the total number of loads and the number of moves between SIMD registers and general-purpose registers.
1 parent bd0335c commit 03f1c8f

25 files changed

+567
-350
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 115 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,12 +1432,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14321432
setOperationAction(ISD::BITCAST, MVT::v2i16, Custom);
14331433
setOperationAction(ISD::BITCAST, MVT::v4i8, Custom);
14341434

1435-
setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
1435+
setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
1436+
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
1437+
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
1438+
setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
1439+
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
1440+
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
1441+
setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
14361442
setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
14371443
setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
1438-
setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
1444+
setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
14391445
setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
14401446
setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
1447+
setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
1448+
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
1449+
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
1450+
setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
1451+
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
1452+
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
14411453

14421454
// ADDP custom lowering
14431455
for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
@@ -6402,8 +6414,34 @@ bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
64026414
return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2;
64036415
}
64046416

6417+
/// Helper function to check if a small vector load can be optimized.
6418+
static bool isEligibleForSmallVectorLoadOpt(LoadSDNode *LD,
6419+
const AArch64Subtarget &Subtarget) {
6420+
if (!Subtarget.isNeonAvailable())
6421+
return false;
6422+
if (LD->isVolatile())
6423+
return false;
6424+
6425+
EVT MemVT = LD->getMemoryVT();
6426+
if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8 && MemVT != MVT::v2i16)
6427+
return false;
6428+
6429+
Align Alignment = LD->getAlign();
6430+
Align RequiredAlignment = Align(MemVT.getStoreSize().getFixedValue());
6431+
if (Subtarget.requiresStrictAlign() && Alignment < RequiredAlignment)
6432+
return false;
6433+
6434+
return true;
6435+
}
6436+
64056437
bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
64066438
EVT ExtVT = ExtVal.getValueType();
6439+
// Small, illegal vectors can be extended inreg.
6440+
if (auto *Load = dyn_cast<LoadSDNode>(ExtVal.getOperand(0))) {
6441+
if (ExtVT.isFixedLengthVector() && ExtVT.getStoreSizeInBits() <= 128 &&
6442+
isEligibleForSmallVectorLoadOpt(Load, *Subtarget))
6443+
return true;
6444+
}
64076445
if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors())
64086446
return false;
64096447

@@ -6859,12 +6897,86 @@ SDValue AArch64TargetLowering::LowerStore128(SDValue Op,
68596897
return Result;
68606898
}
68616899

6900+
/// Helper function to optimize loads of extended small vectors.
6901+
/// These patterns would otherwise get scalarized into inefficient sequences.
6902+
static SDValue tryLowerSmallVectorExtLoad(LoadSDNode *Load, SelectionDAG &DAG) {
6903+
const AArch64Subtarget &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
6904+
if (!isEligibleForSmallVectorLoadOpt(Load, Subtarget))
6905+
return SDValue();
6906+
6907+
EVT MemVT = Load->getMemoryVT();
6908+
EVT ResVT = Load->getValueType(0);
6909+
unsigned NumElts = ResVT.getVectorNumElements();
6910+
unsigned DstEltBits = ResVT.getScalarSizeInBits();
6911+
unsigned SrcEltBits = MemVT.getScalarSizeInBits();
6912+
6913+
unsigned ExtOpcode;
6914+
switch (Load->getExtensionType()) {
6915+
case ISD::EXTLOAD:
6916+
case ISD::ZEXTLOAD:
6917+
ExtOpcode = ISD::ZERO_EXTEND;
6918+
break;
6919+
case ISD::SEXTLOAD:
6920+
ExtOpcode = ISD::SIGN_EXTEND;
6921+
break;
6922+
case ISD::NON_EXTLOAD:
6923+
return SDValue();
6924+
}
6925+
6926+
SDLoc DL(Load);
6927+
SDValue Chain = Load->getChain();
6928+
SDValue BasePtr = Load->getBasePtr();
6929+
const MachinePointerInfo &PtrInfo = Load->getPointerInfo();
6930+
Align Alignment = Load->getAlign();
6931+
6932+
// Load the data as an FP scalar to avoid issues with integer loads.
6933+
unsigned LoadBits = MemVT.getStoreSizeInBits();
6934+
MVT ScalarLoadType = MVT::getFloatingPointVT(LoadBits);
6935+
SDValue ScalarLoad =
6936+
DAG.getLoad(ScalarLoadType, DL, Chain, BasePtr, PtrInfo, Alignment);
6937+
6938+
MVT ScalarToVecTy = MVT::getVectorVT(ScalarLoadType, 128 / LoadBits);
6939+
SDValue ScalarToVec =
6940+
DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ScalarToVecTy, ScalarLoad);
6941+
MVT BitcastTy =
6942+
MVT::getVectorVT(MVT::getIntegerVT(SrcEltBits), 128 / SrcEltBits);
6943+
SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, BitcastTy, ScalarToVec);
6944+
6945+
SDValue Res = Bitcast;
6946+
unsigned CurrentEltBits = Res.getValueType().getScalarSizeInBits();
6947+
unsigned CurrentNumElts = Res.getValueType().getVectorNumElements();
6948+
while (CurrentEltBits < DstEltBits) {
6949+
if (Res.getValueSizeInBits() >= 128) {
6950+
CurrentNumElts = CurrentNumElts / 2;
6951+
MVT ExtractVT =
6952+
MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
6953+
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, Res,
6954+
DAG.getConstant(0, DL, MVT::i64));
6955+
}
6956+
CurrentEltBits = CurrentEltBits * 2;
6957+
MVT ExtVT =
6958+
MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
6959+
Res = DAG.getNode(ExtOpcode, DL, ExtVT, Res);
6960+
}
6961+
6962+
if (CurrentNumElts != NumElts) {
6963+
MVT FinalVT = MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), NumElts);
6964+
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FinalVT, Res,
6965+
DAG.getConstant(0, DL, MVT::i64));
6966+
}
6967+
6968+
return DAG.getMergeValues({Res, ScalarLoad.getValue(1)}, DL);
6969+
}
6970+
68626971
SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
68636972
SelectionDAG &DAG) const {
68646973
SDLoc DL(Op);
68656974
LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
68666975
assert(LoadNode && "Expected custom lowering of a load node");
68676976

6977+
if (SDValue Result = tryLowerSmallVectorExtLoad(LoadNode, DAG))
6978+
return Result;
6979+
68686980
if (LoadNode->getMemoryVT() == MVT::i64x8) {
68696981
SmallVector<SDValue, 8> Ops;
68706982
SDValue Base = LoadNode->getBasePtr();
@@ -6883,37 +6995,7 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
68836995
return DAG.getMergeValues({Loaded, Chain}, DL);
68846996
}
68856997

6886-
// Custom lowering for extending v4i8 vector loads.
6887-
EVT VT = Op->getValueType(0);
6888-
assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
6889-
6890-
if (LoadNode->getMemoryVT() != MVT::v4i8)
6891-
return SDValue();
6892-
6893-
// Avoid generating unaligned loads.
6894-
if (Subtarget->requiresStrictAlign() && LoadNode->getAlign() < Align(4))
6895-
return SDValue();
6896-
6897-
unsigned ExtType;
6898-
if (LoadNode->getExtensionType() == ISD::SEXTLOAD)
6899-
ExtType = ISD::SIGN_EXTEND;
6900-
else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD ||
6901-
LoadNode->getExtensionType() == ISD::EXTLOAD)
6902-
ExtType = ISD::ZERO_EXTEND;
6903-
else
6904-
return SDValue();
6905-
6906-
SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(),
6907-
LoadNode->getBasePtr(), MachinePointerInfo());
6908-
SDValue Chain = Load.getValue(1);
6909-
SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load);
6910-
SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec);
6911-
SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC);
6912-
Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext,
6913-
DAG.getConstant(0, DL, MVT::i64));
6914-
if (VT == MVT::v4i32)
6915-
Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext);
6916-
return DAG.getMergeValues({Ext, Chain}, DL);
6998+
return SDValue();
69176999
}
69187000

69197001
SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,

0 commit comments

Comments
 (0)