@@ -1432,12 +1432,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::BITCAST, MVT::v2i16, Custom);
     setOperationAction(ISD::BITCAST, MVT::v4i8, Custom);
 
-    setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
+    setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
+    setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i8, Custom);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
+    setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
+    setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i8, Custom);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
     setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
     setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
-    setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
     setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
     setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
+    setLoadExtAction(ISD::SEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
+    setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i32, MVT::v2i16, Custom);
+    setLoadExtAction(ISD::EXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
+    setLoadExtAction(ISD::SEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
+    setLoadExtAction(ISD::ZEXTLOAD, MVT::v2i64, MVT::v2i16, Custom);
 
     // ADDP custom lowering
     for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
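For context, marking these extending loads Custom keeps a small-vector load-plus-extend as a single extload node that the lowering code below can handle as one unit. A minimal sketch of source code that produces one of the newly covered patterns (the function name is hypothetical, and it assumes the loop vectorizer turns the loop into a v4i8-to-v4i32 zextload):

// Sketch: with the extload marked Custom, this widening copy can
// compile to one 32-bit load plus vector extends instead of four
// scalar byte loads (assuming the loop is vectorized).
void widen4(const unsigned char *src, unsigned int *dst) {
  for (int i = 0; i < 4; ++i)
    dst[i] = src[i];
}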
@@ -6402,8 +6414,34 @@ bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
   return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2;
 }
 
+/// Helper function to check if a small vector load can be optimized.
+static bool isEligibleForSmallVectorLoadOpt(LoadSDNode *LD,
+                                            const AArch64Subtarget &Subtarget) {
+  if (!Subtarget.isNeonAvailable())
+    return false;
+  if (LD->isVolatile())
+    return false;
+
+  EVT MemVT = LD->getMemoryVT();
+  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8 && MemVT != MVT::v2i16)
+    return false;
+
+  Align Alignment = LD->getAlign();
+  Align RequiredAlignment = Align(MemVT.getStoreSize().getFixedValue());
+  if (Subtarget.requiresStrictAlign() && Alignment < RequiredAlignment)
+    return false;
+
+  return true;
+}
+
 bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
   EVT ExtVT = ExtVal.getValueType();
+  // Small, illegal vectors can be extended inreg.
+  if (auto *Load = dyn_cast<LoadSDNode>(ExtVal.getOperand(0))) {
+    if (ExtVT.isFixedLengthVector() && ExtVT.getStoreSizeInBits() <= 128 &&
+        isEligibleForSmallVectorLoadOpt(Load, *Subtarget))
+      return true;
+  }
   if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors())
     return false;
 
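The new early-out in isVectorLoadExtDesirable reports the extend as desirable so the DAG combiner folds it into the load, letting the eligible cases reach the custom lowering below. As a standalone illustration (not LLVM code) of the shapes the eligibility check accepts, with the minimum alignments it requires under strict-align, derived from the store sizes of the three memory types above:

#include <cstdio>

// Prints the (MemVT, minimum strict-align alignment) pairs accepted by
// the eligibility check: v2i8 stores 2 bytes; v4i8 and v2i16 store 4.
int main() {
  struct { const char *MemVT; unsigned StoreBytes; } Shapes[] = {
      {"v2i8", 2}, {"v4i8", 4}, {"v2i16", 4}};
  for (const auto &S : Shapes)
    printf("%-5s needs align >= %u bytes under strict-align\n", S.MemVT,
           S.StoreBytes);
}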
@@ -6859,12 +6897,86 @@ SDValue AArch64TargetLowering::LowerStore128(SDValue Op,
   return Result;
 }
 
+/// Helper function to optimize loads of extended small vectors.
+/// These patterns would otherwise get scalarized into inefficient sequences.
+static SDValue tryLowerSmallVectorExtLoad(LoadSDNode *Load, SelectionDAG &DAG) {
+  const AArch64Subtarget &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+  if (!isEligibleForSmallVectorLoadOpt(Load, Subtarget))
+    return SDValue();
+
+  EVT MemVT = Load->getMemoryVT();
+  EVT ResVT = Load->getValueType(0);
+  unsigned NumElts = ResVT.getVectorNumElements();
+  unsigned DstEltBits = ResVT.getScalarSizeInBits();
+  unsigned SrcEltBits = MemVT.getScalarSizeInBits();
+
+  unsigned ExtOpcode;
+  switch (Load->getExtensionType()) {
+  case ISD::EXTLOAD:
+  case ISD::ZEXTLOAD:
+    ExtOpcode = ISD::ZERO_EXTEND;
+    break;
+  case ISD::SEXTLOAD:
+    ExtOpcode = ISD::SIGN_EXTEND;
+    break;
+  case ISD::NON_EXTLOAD:
+    return SDValue();
+  }
+
+  SDLoc DL(Load);
+  SDValue Chain = Load->getChain();
+  SDValue BasePtr = Load->getBasePtr();
+  const MachinePointerInfo &PtrInfo = Load->getPointerInfo();
+  Align Alignment = Load->getAlign();
+
+  // Load the data as an FP scalar to avoid issues with integer loads.
+  unsigned LoadBits = MemVT.getStoreSizeInBits();
+  MVT ScalarLoadType = MVT::getFloatingPointVT(LoadBits);
+  SDValue ScalarLoad =
+      DAG.getLoad(ScalarLoadType, DL, Chain, BasePtr, PtrInfo, Alignment);
+
+  MVT ScalarToVecTy = MVT::getVectorVT(ScalarLoadType, 128 / LoadBits);
+  SDValue ScalarToVec =
+      DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ScalarToVecTy, ScalarLoad);
+  MVT BitcastTy =
+      MVT::getVectorVT(MVT::getIntegerVT(SrcEltBits), 128 / SrcEltBits);
+  SDValue Bitcast = DAG.getNode(ISD::BITCAST, DL, BitcastTy, ScalarToVec);
+
+  SDValue Res = Bitcast;
+  unsigned CurrentEltBits = Res.getValueType().getScalarSizeInBits();
+  unsigned CurrentNumElts = Res.getValueType().getVectorNumElements();
+  while (CurrentEltBits < DstEltBits) {
+    if (Res.getValueSizeInBits() >= 128) {
+      CurrentNumElts = CurrentNumElts / 2;
+      MVT ExtractVT =
+          MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
+      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtractVT, Res,
+                        DAG.getConstant(0, DL, MVT::i64));
+    }
+    CurrentEltBits = CurrentEltBits * 2;
+    MVT ExtVT =
+        MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), CurrentNumElts);
+    Res = DAG.getNode(ExtOpcode, DL, ExtVT, Res);
+  }
+
+  if (CurrentNumElts != NumElts) {
+    MVT FinalVT = MVT::getVectorVT(MVT::getIntegerVT(CurrentEltBits), NumElts);
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, FinalVT, Res,
+                      DAG.getConstant(0, DL, MVT::i64));
+  }
+
+  return DAG.getMergeValues({Res, ScalarLoad.getValue(1)}, DL);
+}
+
 SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
                                          SelectionDAG &DAG) const {
   SDLoc DL(Op);
   LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
   assert(LoadNode && "Expected custom lowering of a load node");
 
+  if (SDValue Result = tryLowerSmallVectorExtLoad(LoadNode, DAG))
+    return Result;
+
   if (LoadNode->getMemoryVT() == MVT::i64x8) {
     SmallVector<SDValue, 8> Ops;
     SDValue Base = LoadNode->getBasePtr();
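To make the widening loop easier to follow, here is a self-contained trace (plain C++, no LLVM dependencies, not part of the patch) of the type progression tryLowerSmallVectorExtLoad produces: whenever the vector already occupies 128 bits it keeps the low half via EXTRACT_SUBVECTOR, then doubles the element width with the extend, and finally trims to the requested lane count:

#include <cstdio>

// Mirrors the loop in tryLowerSmallVectorExtLoad on type shapes only.
static void trace(unsigned SrcEltBits, unsigned DstEltBits, unsigned NumElts) {
  unsigned EltBits = SrcEltBits, Elts = 128 / SrcEltBits;
  printf("v%ui%u -> v%ui%u: v%ui%u", NumElts, SrcEltBits, NumElts, DstEltBits,
         Elts, EltBits);
  while (EltBits < DstEltBits) {
    if (EltBits * Elts >= 128)
      Elts /= 2;  // EXTRACT_SUBVECTOR of the low half
    EltBits *= 2; // ZERO_EXTEND or SIGN_EXTEND doubles the lane width
    printf(" -> v%ui%u", Elts, EltBits);
  }
  if (Elts != NumElts) // final EXTRACT_SUBVECTOR down to NumElts lanes
    printf(" -> v%ui%u", NumElts, EltBits);
  printf("\n");
}

int main() {
  trace(8, 16, 4);  // v4i8  -> v4i16: v16i8 -> v8i16 -> v4i16
  trace(8, 32, 4);  // v4i8  -> v4i32: v16i8 -> v8i16 -> v4i32
  trace(8, 64, 2);  // v2i8  -> v2i64: v16i8 -> v8i16 -> v4i32 -> v2i64
  trace(16, 32, 2); // v2i16 -> v2i32: v8i16 -> v4i32 -> v2i32
}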
@@ -6883,37 +6995,7 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
     return DAG.getMergeValues({Loaded, Chain}, DL);
   }
 
-  // Custom lowering for extending v4i8 vector loads.
-  EVT VT = Op->getValueType(0);
-  assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
-
-  if (LoadNode->getMemoryVT() != MVT::v4i8)
-    return SDValue();
-
-  // Avoid generating unaligned loads.
-  if (Subtarget->requiresStrictAlign() && LoadNode->getAlign() < Align(4))
-    return SDValue();
-
-  unsigned ExtType;
-  if (LoadNode->getExtensionType() == ISD::SEXTLOAD)
-    ExtType = ISD::SIGN_EXTEND;
-  else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD ||
-           LoadNode->getExtensionType() == ISD::EXTLOAD)
-    ExtType = ISD::ZERO_EXTEND;
-  else
-    return SDValue();
-
-  SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(),
-                             LoadNode->getBasePtr(), MachinePointerInfo());
-  SDValue Chain = Load.getValue(1);
-  SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load);
-  SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec);
-  SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC);
-  Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext,
-                    DAG.getConstant(0, DL, MVT::i64));
-  if (VT == MVT::v4i32)
-    Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext);
-  return DAG.getMergeValues({Ext, Chain}, DL);
+  return SDValue();
 }
 
 SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
                                                     SelectionDAG &DAG) const {