148 changes: 115 additions & 33 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1432,12 +1432,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BITCAST, MVT::v2i16, Custom);
setOperationAction(ISD::BITCAST, MVT::v4i8, Custom);

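// Extending loads from small (sub-64-bit) vector memory types are lowered
// in LowerLOAD via tryLowerSmallVectorExtLoad instead of being scalarized.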
setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);

// ADDP custom lowering
for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
@@ -6402,8 +6414,34 @@ bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2;
}

/// Helper function to check if a small vector load can be optimized.
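/// Eligible sources are v2i8, v4i8 and v2i16 memory types; for example (an
/// IR-level sketch, not taken from the patch):
///   %v = load <4 x i8>, ptr %p    ; extended to <4 x i16> or <4 x i32>
///   %w = load <2 x i16>, ptr %p   ; extended to <2 x i32> or <2 x i64>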
static bool isEligibleForSmallVectorLoadOpt(LoadSDNode *LD,
const AArch64Subtarget &Subtarget) {
if (!Subtarget.isNeonAvailable())
return false;
if (LD->isVolatile())
return false;

EVT MemVT = LD->getMemoryVT();
if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8 && MemVT != MVT::v2i16)
return false;

Align Alignment = LD->getAlign();
Align RequiredAlignment = Align(MemVT.getStoreSize().getFixedValue());
if (Subtarget.requiresStrictAlign() && Alignment < RequiredAlignment)
return false;

return true;
}

bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
EVT ExtVT = ExtVal.getValueType();
// Extending loads of small, illegal vectors are profitable here: they are
// custom-lowered by tryLowerSmallVectorExtLoad instead of being scalarized.
if (auto *Load = dyn_cast<LoadSDNode>(ExtVal.getOperand(0))) {
if (ExtVT.isFixedLengthVector() && ExtVT.getStoreSizeInBits() <= 128 &&
isEligibleForSmallVectorLoadOpt(Load, *Subtarget))
return true;
}
if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors())
return false;

@@ -6859,12 +6897,86 @@ SDValue AArch64TargetLowering::LowerStore128(SDValue Op,
return Result;
}

/// Helper function to lower extending loads of small vectors as a single
/// scalar FP load followed by vector extends. These patterns would otherwise
/// be scalarized into inefficient load/insert/extend sequences.
static SDValue tryLowerSmallVectorExtLoad(LoadSDNode *Load, SelectionDAG &DAG) {
const AArch64Subtarget &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
if (!isEligibleForSmallVectorLoadOpt(Load, Subtarget))
return SDValue();

EVT MemVT = Load->getMemoryVT();
EVT ResVT = Load->getValueType(0);
unsigned NumElts = ResVT.getVectorNumElements();
unsigned DstEltBits = ResVT.getScalarSizeInBits();
unsigned SrcEltBits = MemVT.getScalarSizeInBits();

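// An any-extending load (EXTLOAD) places no requirement on the high bits,
// so it is lowered as a zero extend.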
unsigned ExtOpcode;
switch (Load->getExtensionType()) {
case ISD::EXTLOAD:
case ISD::ZEXTLOAD:
ExtOpcode = ISD::ZERO_EXTEND;
break;
case ISD::SEXTLOAD:
ExtOpcode = ISD::SIGN_EXTEND;
break;
case ISD::NON_EXTLOAD:
return SDValue();
}

SDLoc DL(Load);
SDValue Chain = Load->getChain();
SDValue BasePtr = Load->getBasePtr();
const MachinePointerInfo &PtrInfo = Load->getPointerInfo();
Align Alignment = Load->getAlign();

// Perform the memory access as a scalar FP load of the same total width so
// the value lands in a NEON register, avoiding a GPR-to-FPR copy.
unsigned LoadBits = MemVT.getStoreSizeInBits();
MVT ScalarLoadType = MVT::getFloatingPointVT(LoadBits);
SDValue ScalarLoad =
DAG.getLoad(ScalarLoadType, DL, Chain, BasePtr, PtrInfo, Alignment);

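// Place the scalar in lane 0 of a 128-bit vector, then reinterpret it as a
// vector of the source element type so vector extends can operate on it.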
MVT ScalarToVecTy = MVT::getVectorVT(ScalarLoadType, 128 / LoadBits);
SDValue ScalarToVec =
DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ScalarToVecTy, ScalarLoad);
MVT BitcastTy =
MVT::getVectorVT(MVT::getIntegerVT(SrcEltBits), 128 / SrcEltBits);
SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, BitcastTy, ScalarToVec);

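// Double the element width step by step until it matches the destination,
// extracting the low half whenever the vector is already 128 bits wide.
// Sketch for a v2i8 -> v2i64 sextload: ldr h0; bitcast to v16i8; extract
// v8i8, sshll to v8i16; extract v4i16, sshll to v4i32; extract v2i32,
// sshll to v2i64.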
SDValue Res = Bitcast;
unsigned CurrentEltBits = Res.getValueType().getScalarSizeInBits();
unsigned CurrentNumElts = Res.getValueType().getVectorNumElements();
while (CurrentEltBits < DstEltBits) {
if (Res.getValueSizeInBits() >= 128) {
CurrentNumElts = CurrentNumElts / 2;
MVT ExtractVT =
MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, Res,
DAG.getConstant(0, DL, MVT::i64));
}
CurrentEltBits = CurrentEltBits * 2;
MVT ExtVT =
MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
Res = DAG.getNode(ExtOpcode, DL, ExtVT, Res);
}

if (CurrentNumElts != NumElts) {
MVT FinalVT = MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), NumElts);
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FinalVT, Res,
DAG.getConstant(0, DL, MVT::i64));
}

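// Return the extended value together with the chain of the narrow scalar
// load so the original load's chain uses are rewired to it.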
return DAG.getMergeValues({Res, ScalarLoad.getValue(1)}, DL);
}
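
// End-to-end sketch of the transform (assuming NEON is available; exact
// instruction selection may vary by subtarget): for
//   %v = load <4 x i8>, ptr %p
//   %e = zext <4 x i8> %v to <4 x i32>
// this path emits roughly:
//   ldr s0, [x0]
//   ushll v0.8h, v0.8b, #0
//   ushll v0.4s, v0.4h, #0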

SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
assert(LoadNode && "Expected custom lowering of a load node");

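// Handle extending loads of small vectors before the generic paths below.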
if (SDValue Result = tryLowerSmallVectorExtLoad(LoadNode, DAG))
return Result;

if (LoadNode->getMemoryVT() == MVT::i64x8) {
SmallVector<SDValue, 8> Ops;
SDValue Base = LoadNode->getBasePtr();
@@ -6883,37 +6995,7 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
return DAG.getMergeValues({Loaded, Chain}, DL);
}

return SDValue();
}

SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,