Skip to content

Commit 7fd16ee

Browse files
heihertru
authored and committed
[LoongArch] Fix failure to widen operand for [X]VMSK{LT,GE,NE}Z (llvm#149442)
Reported-by: tangyan <[email protected]> (cherry picked from commit 8a307ae)
1 parent 25c1d7a commit 7fd16ee

File tree

2 files changed

+139
-97
lines changed

2 files changed

+139
-97
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 124 additions & 97 deletions
Original file line number | Diff line number | Diff line change
@@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
45634563
llvm_unreachable("Unexpected node type for vXi1 sign extension");
45644564
}
45654565

4566+
// Try to fold a BITCAST whose operand is a single-use vector SETCC into a
// single [X]VMSK{EQ,GE,LT,NE}Z node applied to the compared vector.
//
// The fold only fires when:
//   * the BITCAST operand is a SETCC with exactly one use,
//   * the compared operand type is exactly 128 bits wide (LSX available) or
//     exactly 256 bits wide (32S and LASX available), and
//   * the comparison is one of the "against all-zeros / all-ones" forms
//     enumerated in the switch below, with a supported element type.
//
// Returns the replacement value bitcast to N's result type, or an empty
// SDValue when the pattern does not match. DCI is not used in the body.
//
// NOTE(review): the body calls getVectorElementType()/getVectorNumElements()
// without checking that the types are vectors; presumably the caller has
// already verified that Src is a vXi1 vector -- confirm at call sites.
static SDValue
4567+
performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
4568+
TargetLowering::DAGCombinerInfo &DCI,
4569+
const LoongArchSubtarget &Subtarget) {
4570+
SDLoc DL(N);
4571+
EVT VT = N->getValueType(0);
4572+
SDValue Src = N->getOperand(0);
4573+
EVT SrcVT = Src.getValueType();
4574+
4575+
// Only a SETCC feeding nothing but this BITCAST is a candidate.
if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
4576+
return SDValue();
4577+
4578+
bool UseLASX;
4579+
// ISD::DELETED_NODE serves as the "no matching mask opcode" sentinel.
unsigned Opc = ISD::DELETED_NODE;
4580+
EVT CmpVT = Src.getOperand(0).getValueType();
4581+
EVT EltVT = CmpVT.getVectorElementType();
4582+
4583+
// The compared operand is passed unchanged to the mask node below, so only
// exact 128-bit (LSX) or 256-bit (LASX) operand types are accepted here;
// any other width makes this helper bail out.
if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
4584+
UseLASX = false;
4585+
else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
4586+
CmpVT.getSizeInBits() == 256)
4587+
UseLASX = true;
4588+
else
4589+
return SDValue();
4590+
4591+
// Map (condition code, constant right-hand side, element type) to a mask
// opcode. Unsupported combinations leave Opc at the sentinel value.
SDValue SrcN1 = Src.getOperand(1);
4592+
switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
4593+
default:
4594+
break;
4595+
case ISD::SETEQ:
4596+
// x == 0 => not (vmsknez.b x)
4597+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4598+
Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4599+
break;
4600+
case ISD::SETGT:
4601+
// x > -1 => vmskgez.b x
4602+
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
4603+
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4604+
break;
4605+
case ISD::SETGE:
4606+
// x >= 0 => vmskgez.b x
4607+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4608+
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4609+
break;
4610+
case ISD::SETLT:
4611+
// x < 0 => vmskltz.{b,h,w,d} x
4612+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
4613+
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4614+
EltVT == MVT::i64))
4615+
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4616+
break;
4617+
case ISD::SETLE:
4618+
// x <= -1 => vmskltz.{b,h,w,d} x
4619+
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
4620+
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4621+
EltVT == MVT::i64))
4622+
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4623+
break;
4624+
case ISD::SETNE:
4625+
// x != 0 => vmsknez.b x
4626+
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4627+
Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4628+
break;
4629+
}
4630+
4631+
// No supported comparison form matched.
if (Opc == ISD::DELETED_NODE)
4632+
return SDValue();
4633+
4634+
// Build the mask node as an i64 from the compared vector, resize it to an
// integer with one bit per element of the SETCC result, then bitcast that
// integer to the type the original BITCAST produced.
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
4635+
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
4636+
V = DAG.getZExtOrTrunc(V, DL, T);
4637+
return DAG.getBitcast(VT, V);
4638+
}
4639+
45664640
static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
45674641
TargetLowering::DAGCombinerInfo &DCI,
45684642
const LoongArchSubtarget &Subtarget) {
@@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
45774651
if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
45784652
return SDValue();
45794653

4580-
unsigned Opc = ISD::DELETED_NODE;
45814654
// Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4655+
SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
4656+
if (Res)
4657+
return Res;
4658+
4659+
// Generate vXi1 using [X]VMSKLTZ
4660+
MVT SExtVT;
4661+
unsigned Opc;
4662+
bool UseLASX = false;
4663+
bool PropagateSExt = false;
4664+
45824665
if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
4583-
bool UseLASX;
45844666
EVT CmpVT = Src.getOperand(0).getValueType();
4585-
EVT EltVT = CmpVT.getVectorElementType();
4586-
4587-
if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
4588-
UseLASX = false;
4589-
else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
4590-
CmpVT.getSizeInBits() <= 256)
4591-
UseLASX = true;
4592-
else
4667+
if (CmpVT.getSizeInBits() > 256)
45934668
return SDValue();
4594-
4595-
SDValue SrcN1 = Src.getOperand(1);
4596-
switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
4597-
default:
4598-
break;
4599-
case ISD::SETEQ:
4600-
// x == 0 => not (vmsknez.b x)
4601-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4602-
Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4603-
break;
4604-
case ISD::SETGT:
4605-
// x > -1 => vmskgez.b x
4606-
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
4607-
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4608-
break;
4609-
case ISD::SETGE:
4610-
// x >= 0 => vmskgez.b x
4611-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4612-
Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4613-
break;
4614-
case ISD::SETLT:
4615-
// x < 0 => vmskltz.{b,h,w,d} x
4616-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
4617-
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4618-
EltVT == MVT::i64))
4619-
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4620-
break;
4621-
case ISD::SETLE:
4622-
// x <= -1 => vmskltz.{b,h,w,d} x
4623-
if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
4624-
(EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4625-
EltVT == MVT::i64))
4626-
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4627-
break;
4628-
case ISD::SETNE:
4629-
// x != 0 => vmsknez.b x
4630-
if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
4631-
Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4632-
break;
4633-
}
46344669
}
46354670

4636-
// Generate vXi1 using [X]VMSKLTZ
4637-
if (Opc == ISD::DELETED_NODE) {
4638-
MVT SExtVT;
4639-
bool UseLASX = false;
4640-
bool PropagateSExt = false;
4641-
switch (SrcVT.getSimpleVT().SimpleTy) {
4642-
default:
4643-
return SDValue();
4644-
case MVT::v2i1:
4645-
SExtVT = MVT::v2i64;
4646-
break;
4647-
case MVT::v4i1:
4648-
SExtVT = MVT::v4i32;
4649-
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4650-
SExtVT = MVT::v4i64;
4651-
UseLASX = true;
4652-
PropagateSExt = true;
4653-
}
4654-
break;
4655-
case MVT::v8i1:
4656-
SExtVT = MVT::v8i16;
4657-
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4658-
SExtVT = MVT::v8i32;
4659-
UseLASX = true;
4660-
PropagateSExt = true;
4661-
}
4662-
break;
4663-
case MVT::v16i1:
4664-
SExtVT = MVT::v16i8;
4665-
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4666-
SExtVT = MVT::v16i16;
4667-
UseLASX = true;
4668-
PropagateSExt = true;
4669-
}
4670-
break;
4671-
case MVT::v32i1:
4672-
SExtVT = MVT::v32i8;
4671+
switch (SrcVT.getSimpleVT().SimpleTy) {
4672+
default:
4673+
return SDValue();
4674+
case MVT::v2i1:
4675+
SExtVT = MVT::v2i64;
4676+
break;
4677+
case MVT::v4i1:
4678+
SExtVT = MVT::v4i32;
4679+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4680+
SExtVT = MVT::v4i64;
46734681
UseLASX = true;
4674-
break;
4675-
};
4676-
if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
4677-
return SDValue();
4678-
Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
4679-
: DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
4680-
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4681-
} else {
4682-
Src = Src.getOperand(0);
4683-
}
4682+
PropagateSExt = true;
4683+
}
4684+
break;
4685+
case MVT::v8i1:
4686+
SExtVT = MVT::v8i16;
4687+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4688+
SExtVT = MVT::v8i32;
4689+
UseLASX = true;
4690+
PropagateSExt = true;
4691+
}
4692+
break;
4693+
case MVT::v16i1:
4694+
SExtVT = MVT::v16i8;
4695+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4696+
SExtVT = MVT::v16i16;
4697+
UseLASX = true;
4698+
PropagateSExt = true;
4699+
}
4700+
break;
4701+
case MVT::v32i1:
4702+
SExtVT = MVT::v32i8;
4703+
UseLASX = true;
4704+
break;
4705+
};
4706+
if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX()))
4707+
return SDValue();
4708+
Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
4709+
: DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
4710+
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
46844711

46854712
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
46864713
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());

llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
588588
%res = bitcast <2 x i1> %y to i2
589589
ret i2 %res
590590
}
591+
592+
; Compares a <4 x i8> vector against zero and bitcasts the <4 x i1> result
; to i4. The expected output below does the compare with vseqi.b, spreads the
; byte results into 32-bit lanes (vilvl.b, vilvl.h, vslli.w 24), and then
; extracts the mask with vmskltz.w.
define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
593+
; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
594+
; CHECK: # %bb.0:
595+
; CHECK-NEXT: vseqi.b $vr0, $vr0, 0
596+
; CHECK-NEXT: vilvl.b $vr0, $vr0, $vr0
597+
; CHECK-NEXT: vilvl.h $vr0, $vr0, $vr0
598+
; CHECK-NEXT: vslli.w $vr0, $vr0, 24
599+
; CHECK-NEXT: vmskltz.w $vr0, $vr0
600+
; CHECK-NEXT: vpickve2gr.hu $a0, $vr0, 0
601+
; CHECK-NEXT: ret
602+
%1 = icmp eq <4 x i8> %a, zeroinitializer
603+
%2 = bitcast <4 x i1> %1 to i4
604+
ret i4 %2
605+
}

0 commit comments

Comments
 (0)