@@ -11122,6 +11122,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
   SDValue BasePtr = MemSD->getBasePtr();
 
   SDValue Mask, PassThru, VL;
+  bool IsExpandingLoad = false;
   if (const auto *VPLoad = dyn_cast<VPLoadSDNode>(Op)) {
     Mask = VPLoad->getMask();
     PassThru = DAG.getUNDEF(VT);
@@ -11130,6 +11131,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
     const auto *MLoad = cast<MaskedLoadSDNode>(Op);
     Mask = MLoad->getMask();
     PassThru = MLoad->getPassThru();
+    IsExpandingLoad = MLoad->isExpandingLoad();
   }
 
   bool IsUnmasked = ISD::isConstantSplatVectorAllOnes(Mask.getNode());
@@ -11149,25 +11151,59 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
   if (!VL)
     VL = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget).second;
 
-  unsigned IntID =
-      IsUnmasked ? Intrinsic::riscv_vle : Intrinsic::riscv_vle_mask;
+  SDValue ExpandingVL;
+  if (!IsUnmasked && IsExpandingLoad) {
+    ExpandingVL = VL;
+    VL =
+        DAG.getNode(RISCVISD::VCPOP_VL, DL, XLenVT, Mask,
+                    getAllOnesMask(Mask.getSimpleValueType(), VL, DL, DAG), VL);
+  }
+
+  unsigned IntID = IsUnmasked || IsExpandingLoad ? Intrinsic::riscv_vle
+                                                 : Intrinsic::riscv_vle_mask;
   SmallVector<SDValue, 8> Ops{Chain, DAG.getTargetConstant(IntID, DL, XLenVT)};
-  if (IsUnmasked)
+  if (IntID == Intrinsic::riscv_vle)
     Ops.push_back(DAG.getUNDEF(ContainerVT));
   else
     Ops.push_back(PassThru);
   Ops.push_back(BasePtr);
-  if (!IsUnmasked)
+  if (IntID == Intrinsic::riscv_vle_mask)
     Ops.push_back(Mask);
   Ops.push_back(VL);
-  if (!IsUnmasked)
+  if (IntID == Intrinsic::riscv_vle_mask)
     Ops.push_back(DAG.getTargetConstant(RISCVII::TAIL_AGNOSTIC, DL, XLenVT));
 
   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
 
   SDValue Result =
       DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops, MemVT, MMO);
   Chain = Result.getValue(1);
+  if (ExpandingVL) {
+    MVT IndexVT = ContainerVT;
+    if (ContainerVT.isFloatingPoint())
+      IndexVT = ContainerVT.changeVectorElementTypeToInteger();
+
+    MVT IndexEltVT = IndexVT.getVectorElementType();
+    bool UseVRGATHEREI16 = false;
+    // If index vector is an i8 vector and the element count exceeds 256, we
+    // should change the element type of index vector to i16 to avoid
+    // overflow.
+    if (IndexEltVT == MVT::i8 && VT.getVectorNumElements() > 256) {
+      // FIXME: We need to do vector splitting manually for LMUL=8 cases.
+      assert(getLMUL(IndexVT) != RISCVII::LMUL_8);
+      IndexVT = IndexVT.changeVectorElementType(MVT::i16);
+      UseVRGATHEREI16 = true;
+    }
+
+    SDValue Iota =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
+                    DAG.getConstant(Intrinsic::riscv_viota, DL, XLenVT),
+                    DAG.getUNDEF(IndexVT), Mask, ExpandingVL);
+    Result =
+        DAG.getNode(UseVRGATHEREI16 ? RISCVISD::VRGATHEREI16_VV_VL
+                                    : RISCVISD::VRGATHER_VV_VL,
+                    DL, ContainerVT, Result, Iota, PassThru, Mask, ExpandingVL);
+  }
 
   if (VT.isFixedLengthVector())
     Result = convertFromScalableVector(VT, Result, DAG, Subtarget);
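
For context on the operation being lowered: a masked expanding load reads consecutive elements from memory and distributes them to the active lanes of the result, while inactive lanes keep the pass-through value. That is why the patch combines VCPOP_VL (count of active lanes, i.e. how many consecutive elements the unit-stride vle actually has to read), riscv_viota (each lane's prefix count of set mask bits, used as a gather index), and VRGATHER_VV_VL / VRGATHEREI16_VV_VL (redistribute the densely loaded elements to their lanes, merging with PassThru under the mask). Below is a minimal scalar sketch of those semantics for illustration only; the function name and signature are invented here and are not part of LLVM or this patch.

#include <array>
#include <cstddef>

// Reference semantics of a masked expanding load (illustration only).
// Consecutive elements of Ptr fill the active lanes of the result; inactive
// lanes keep PassThru. `Loaded` plays the role of viota.m: for each lane it
// is the number of active lanes before it, i.e. the index of the densely
// loaded element that vrgather moves into that lane. Its final value equals
// vcpop(Mask), the element count handed to the unit-stride vle above.
template <typename T, std::size_t N>
std::array<T, N> expandLoadRef(const T *Ptr, const std::array<bool, N> &Mask,
                               const std::array<T, N> &PassThru) {
  std::array<T, N> Result;
  std::size_t Loaded = 0;
  for (std::size_t I = 0; I < N; ++I)
    Result[I] = Mask[I] ? Ptr[Loaded++] : PassThru[I];
  return Result;
}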