@@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
45634563 llvm_unreachable (" Unexpected node type for vXi1 sign extension" );
45644564}
45654565
4566+ static SDValue
4567+ performSETCC_BITCASTCombine (SDNode *N, SelectionDAG &DAG,
4568+ TargetLowering::DAGCombinerInfo &DCI,
4569+ const LoongArchSubtarget &Subtarget) {
4570+ SDLoc DL (N);
4571+ EVT VT = N->getValueType (0 );
4572+ SDValue Src = N->getOperand (0 );
4573+ EVT SrcVT = Src.getValueType ();
4574+
4575+ if (Src.getOpcode () != ISD::SETCC || !Src.hasOneUse ())
4576+ return SDValue ();
4577+
4578+ bool UseLASX;
4579+ unsigned Opc = ISD::DELETED_NODE;
4580+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4581+ EVT EltVT = CmpVT.getVectorElementType ();
4582+
4583+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () == 128 )
4584+ UseLASX = false ;
4585+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4586+ CmpVT.getSizeInBits () == 256 )
4587+ UseLASX = true ;
4588+ else
4589+ return SDValue ();
4590+
4591+ SDValue SrcN1 = Src.getOperand (1 );
4592+ switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4593+ default :
4594+ break ;
4595+ case ISD::SETEQ:
4596+ // x == 0 => not (vmsknez.b x)
4597+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4598+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4599+ break ;
4600+ case ISD::SETGT:
4601+ // x > -1 => vmskgez.b x
4602+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4603+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4604+ break ;
4605+ case ISD::SETGE:
4606+ // x >= 0 => vmskgez.b x
4607+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4608+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4609+ break ;
4610+ case ISD::SETLT:
4611+ // x < 0 => vmskltz.{b,h,w,d} x
4612+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4613+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4614+ EltVT == MVT::i64 ))
4615+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4616+ break ;
4617+ case ISD::SETLE:
4618+ // x <= -1 => vmskltz.{b,h,w,d} x
4619+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4620+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4621+ EltVT == MVT::i64 ))
4622+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4623+ break ;
4624+ case ISD::SETNE:
4625+ // x != 0 => vmsknez.b x
4626+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4627+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4628+ break ;
4629+ }
4630+
4631+ if (Opc == ISD::DELETED_NODE)
4632+ return SDValue ();
4633+
4634+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src.getOperand (0 ));
4635+ EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4636+ V = DAG.getZExtOrTrunc (V, DL, T);
4637+ return DAG.getBitcast (VT, V);
4638+ }
4639+
45664640static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
45674641 TargetLowering::DAGCombinerInfo &DCI,
45684642 const LoongArchSubtarget &Subtarget) {
@@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
45774651 if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
45784652 return SDValue ();
45794653
4580- unsigned Opc = ISD::DELETED_NODE;
45814654 // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4655+ SDValue Res = performSETCC_BITCASTCombine (N, DAG, DCI, Subtarget);
4656+ if (Res)
4657+ return Res;
4658+
4659+ // Generate vXi1 using [X]VMSKLTZ
4660+ MVT SExtVT;
4661+ unsigned Opc;
4662+ bool UseLASX = false ;
4663+ bool PropagateSExt = false ;
4664+
45824665 if (Src.getOpcode () == ISD::SETCC && Src.hasOneUse ()) {
4583- bool UseLASX;
45844666 EVT CmpVT = Src.getOperand (0 ).getValueType ();
4585- EVT EltVT = CmpVT.getVectorElementType ();
4586-
4587- if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4588- UseLASX = false ;
4589- else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4590- CmpVT.getSizeInBits () <= 256 )
4591- UseLASX = true ;
4592- else
4667+ if (CmpVT.getSizeInBits () > 256 )
45934668 return SDValue ();
4594-
4595- SDValue SrcN1 = Src.getOperand (1 );
4596- switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4597- default :
4598- break ;
4599- case ISD::SETEQ:
4600- // x == 0 => not (vmsknez.b x)
4601- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4602- Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4603- break ;
4604- case ISD::SETGT:
4605- // x > -1 => vmskgez.b x
4606- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4607- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4608- break ;
4609- case ISD::SETGE:
4610- // x >= 0 => vmskgez.b x
4611- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4612- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4613- break ;
4614- case ISD::SETLT:
4615- // x < 0 => vmskltz.{b,h,w,d} x
4616- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4617- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4618- EltVT == MVT::i64 ))
4619- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4620- break ;
4621- case ISD::SETLE:
4622- // x <= -1 => vmskltz.{b,h,w,d} x
4623- if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4624- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4625- EltVT == MVT::i64 ))
4626- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4627- break ;
4628- case ISD::SETNE:
4629- // x != 0 => vmsknez.b x
4630- if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4631- Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4632- break ;
4633- }
46344669 }
46354670
4636- // Generate vXi1 using [X]VMSKLTZ
4637- if (Opc == ISD::DELETED_NODE) {
4638- MVT SExtVT;
4639- bool UseLASX = false ;
4640- bool PropagateSExt = false ;
4641- switch (SrcVT.getSimpleVT ().SimpleTy ) {
4642- default :
4643- return SDValue ();
4644- case MVT::v2i1:
4645- SExtVT = MVT::v2i64;
4646- break ;
4647- case MVT::v4i1:
4648- SExtVT = MVT::v4i32;
4649- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4650- SExtVT = MVT::v4i64;
4651- UseLASX = true ;
4652- PropagateSExt = true ;
4653- }
4654- break ;
4655- case MVT::v8i1:
4656- SExtVT = MVT::v8i16;
4657- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4658- SExtVT = MVT::v8i32;
4659- UseLASX = true ;
4660- PropagateSExt = true ;
4661- }
4662- break ;
4663- case MVT::v16i1:
4664- SExtVT = MVT::v16i8;
4665- if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4666- SExtVT = MVT::v16i16;
4667- UseLASX = true ;
4668- PropagateSExt = true ;
4669- }
4670- break ;
4671- case MVT::v32i1:
4672- SExtVT = MVT::v32i8;
4671+ switch (SrcVT.getSimpleVT ().SimpleTy ) {
4672+ default :
4673+ return SDValue ();
4674+ case MVT::v2i1:
4675+ SExtVT = MVT::v2i64;
4676+ break ;
4677+ case MVT::v4i1:
4678+ SExtVT = MVT::v4i32;
4679+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4680+ SExtVT = MVT::v4i64;
46734681 UseLASX = true ;
4674- break ;
4675- };
4676- if (UseLASX && !Subtarget.has32S () && !Subtarget.hasExtLASX ())
4677- return SDValue ();
4678- Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4679- : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4680- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4681- } else {
4682- Src = Src.getOperand (0 );
4683- }
4682+ PropagateSExt = true ;
4683+ }
4684+ break ;
4685+ case MVT::v8i1:
4686+ SExtVT = MVT::v8i16;
4687+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4688+ SExtVT = MVT::v8i32;
4689+ UseLASX = true ;
4690+ PropagateSExt = true ;
4691+ }
4692+ break ;
4693+ case MVT::v16i1:
4694+ SExtVT = MVT::v16i8;
4695+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4696+ SExtVT = MVT::v16i16;
4697+ UseLASX = true ;
4698+ PropagateSExt = true ;
4699+ }
4700+ break ;
4701+ case MVT::v32i1:
4702+ SExtVT = MVT::v32i8;
4703+ UseLASX = true ;
4704+ break ;
4705+ };
4706+ if (UseLASX && !(Subtarget.has32S () && Subtarget.hasExtLASX ()))
4707+ return SDValue ();
4708+ Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4709+ : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4710+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
46844711
46854712 SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src);
46864713 EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
0 commit comments