@@ -6971,7 +6971,7 @@ static bool hasPassthruOp(unsigned Opcode) {
69716971 Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
69726972 "not a RISC-V target specific op");
69736973 static_assert(
6974- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
6974+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
69756975 RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
69766976 "adding target specific op should update this function");
69776977 if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6995,7 +6995,7 @@ static bool hasMaskOp(unsigned Opcode) {
69956995 Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
69966996 "not a RISC-V target specific op");
69976997 static_assert(
6998- RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
6998+ RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
69996999 RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
70007000 "adding target specific op should update this function");
70017001 if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18101,6 +18101,118 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
1810118101 DAG.getBuildVector(VT, DL, RHSOps));
1810218102}
1810318103
18104+ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
18105+ const SDLoc &DL, SelectionDAG &DAG,
18106+ const RISCVSubtarget &Subtarget) {
18107+ assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
18108+ RISCVISD::VQDOTSU_VL == Opc);
18109+ MVT VT = Op0.getSimpleValueType();
18110+ assert(VT == Op1.getSimpleValueType() &&
18111+ VT.getVectorElementType() == MVT::i32);
18112+
18113+ assert(VT.isFixedLengthVector());
18114+ MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18115+ SDValue Passthru = convertToScalableVector(
18116+ ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
18117+ Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18118+ Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18119+
18120+ auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
18121+ const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
18122+ SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
18123+ SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
18124+ {Op0, Op1, Passthru, Mask, VL, PolicyOp});
18125+ return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18126+ }
18127+
18128+ static MVT getQDOTXResultType(MVT OpVT) {
18129+ ElementCount OpEC = OpVT.getVectorElementCount();
18130+ assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
18131+ return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
18132+ }
18133+
18134+ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
18135+ SelectionDAG &DAG,
18136+ const RISCVSubtarget &Subtarget,
18137+ const RISCVTargetLowering &TLI) {
18138+ // Note: We intentionally do not check the legality of the reduction type.
18139+ // We want to handle the m4/m8 *src* types, and thus need to let illegal
18140+ // intermediate types flow through here.
18141+ if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
18142+ !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
18143+ return SDValue();
18144+
18145+ // reduce (zext a) <--> reduce (mul zext a. zext 1)
18146+ // reduce (sext a) <--> reduce (mul sext a. sext 1)
18147+ if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
18148+ InVec.getOpcode() == ISD::SIGN_EXTEND) {
18149+ SDValue A = InVec.getOperand(0);
18150+ if (A.getValueType().getVectorElementType() != MVT::i8 ||
18151+ !TLI.isTypeLegal(A.getValueType()))
18152+ return SDValue();
18153+
18154+ MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
18155+ A = DAG.getBitcast(ResVT, A);
18156+ SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
18157+
18158+ bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
18159+ unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
18160+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18161+ }
18162+
18163+ // mul (sext, sext) -> vqdot
18164+ // mul (zext, zext) -> vqdotu
18165+ // mul (sext, zext) -> vqdotsu
18166+ // mul (zext, sext) -> vqdotsu (swapped)
18167+ // TODO: Improve .vx handling - we end up with a sub-vector insert
18168+ // which confuses the splat pattern matching. Also, match vqdotus.vx
18169+ if (InVec.getOpcode() != ISD::MUL)
18170+ return SDValue();
18171+
18172+ SDValue A = InVec.getOperand(0);
18173+ SDValue B = InVec.getOperand(1);
18174+ unsigned Opc = 0;
18175+ if (A.getOpcode() == B.getOpcode()) {
18176+ if (A.getOpcode() == ISD::SIGN_EXTEND)
18177+ Opc = RISCVISD::VQDOT_VL;
18178+ else if (A.getOpcode() == ISD::ZERO_EXTEND)
18179+ Opc = RISCVISD::VQDOTU_VL;
18180+ else
18181+ return SDValue();
18182+ } else {
18183+ if (B.getOpcode() != ISD::ZERO_EXTEND)
18184+ std::swap(A, B);
18185+ if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
18186+ return SDValue();
18187+ Opc = RISCVISD::VQDOTSU_VL;
18188+ }
18189+ assert(Opc);
18190+
18191+ if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
18192+ A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
18193+ !TLI.isTypeLegal(A.getValueType()))
18194+ return SDValue();
18195+
18196+ MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
18197+ A = DAG.getBitcast(ResVT, A.getOperand(0));
18198+ B = DAG.getBitcast(ResVT, B.getOperand(0));
18199+ return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18200+ }
18201+
18202+ static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
18203+ const RISCVSubtarget &Subtarget,
18204+ const RISCVTargetLowering &TLI) {
18205+ if (!Subtarget.hasStdExtZvqdotq())
18206+ return SDValue();
18207+
18208+ SDLoc DL(N);
18209+ EVT VT = N->getValueType(0);
18210+ SDValue InVec = N->getOperand(0);
18211+ if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
18212+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
18213+ return SDValue();
18214+ }
18215+
1810418216static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
1810518217 const RISCVSubtarget &Subtarget,
1810618218 const RISCVTargetLowering &TLI) {
@@ -19878,8 +19990,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1987819990
1987919991 return SDValue();
1988019992 }
19881- case ISD::CTPOP:
1988219993 case ISD::VECREDUCE_ADD:
19994+ if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
19995+ return V;
19996+ [[fallthrough]];
19997+ case ISD::CTPOP:
1988319998 if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
1988419999 return V;
1988520000 break;
@@ -22401,6 +22516,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2240122516 NODE_NAME_CASE(RI_VUNZIP2A_VL)
2240222517 NODE_NAME_CASE(RI_VUNZIP2B_VL)
2240322518 NODE_NAME_CASE(RI_VEXTRACT)
22519+ NODE_NAME_CASE(VQDOT_VL)
22520+ NODE_NAME_CASE(VQDOTU_VL)
22521+ NODE_NAME_CASE(VQDOTSU_VL)
2240422522 NODE_NAME_CASE(READ_CSR)
2240522523 NODE_NAME_CASE(WRITE_CSR)
2240622524 NODE_NAME_CASE(SWAP_CSR)
0 commit comments