Skip to content

Commit 541d4cb

Browse files
author
git apple-llvm automerger
committed
Merge commit '1651aa294342' from llvm.org/main into next
2 parents b7b853c + 1651aa2 commit 541d4cb

File tree

5 files changed

+61
-43
lines changed

5 files changed

+61
-43
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,17 +1661,20 @@ class LLVM_ABI TargetLoweringBase {
16611661
/// InputVT should be treated. Either it's legal, needs to be promoted to a
16621662
/// larger size, needs to be expanded to some other code sequence, or the
16631663
/// target has a custom expander for it.
1664-
LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
1665-
PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
1666-
InputVT.getSimpleVT().SimpleTy};
1667-
auto It = PartialReduceMLAActions.find(TypePair);
1664+
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
1665+
EVT InputVT) const {
1666+
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
1667+
PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
1668+
InputVT.getSimpleVT().SimpleTy};
1669+
auto It = PartialReduceMLAActions.find(Key);
16681670
return It != PartialReduceMLAActions.end() ? It->second : Expand;
16691671
}
16701672

16711673
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
16721674
/// legal or custom for this target.
1673-
bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
1674-
LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
1675+
bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT,
1676+
EVT InputVT) const {
1677+
LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT);
16751678
return Action == Legal || Action == Custom;
16761679
}
16771680

@@ -2756,12 +2759,18 @@ class LLVM_ABI TargetLoweringBase {
27562759
/// type InputVT should be treated by the target. Either it's legal, needs to
27572760
/// be promoted to a larger size, needs to be expanded to some other code
27582761
/// sequence, or the target has a custom expander for it.
2759-
void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
2762+
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
27602763
LegalizeAction Action) {
2764+
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
27612765
assert(AccVT.isValid() && InputVT.isValid() &&
27622766
"setPartialReduceMLAAction types aren't valid");
2763-
PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
2764-
PartialReduceMLAActions[TypePair] = Action;
2767+
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
2768+
PartialReduceMLAActions[Key] = Action;
2769+
}
2770+
void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT,
2771+
MVT InputVT, LegalizeAction Action) {
2772+
for (unsigned Opc : Opcodes)
2773+
setPartialReduceMLAAction(Opc, AccVT, InputVT, Action);
27652774
}
27662775

27672776
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
@@ -3753,10 +3762,10 @@ class LLVM_ABI TargetLoweringBase {
37533762
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
37543763

37553764
using PartialReduceActionTypes =
3756-
std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
3757-
/// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
3758-
/// nodes, keep a LegalizeAction which indicates how instruction selection
3759-
/// should deal with this operation.
3765+
std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>;
3766+
/// For each partial reduce opcode, result type and input type combination,
3767+
/// keep a LegalizeAction which indicates how instruction selection should
3768+
/// deal with this operation.
37603769
DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;
37613770

37623771
ValueTypeActionImpl ValueTypeActions;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1267312673
SDValue LHSExtOp = LHS->getOperand(0);
1267412674
EVT LHSExtOpVT = LHSExtOp.getValueType();
1267512675

12676+
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12677+
unsigned NewOpcode =
12678+
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12679+
1267612680
// Only perform these combines if the target supports folding
1267712681
// the extends into the operation.
1267812682
if (!TLI.isPartialReduceMLALegalOrCustom(
12679-
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12683+
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
1268012684
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
1268112685
return SDValue();
1268212686

12683-
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
12684-
unsigned NewOpcode =
12685-
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12686-
1268712687
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1268812688
// -> partial_reduce_*mla(acc, x, C)
1268912689
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
@@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1273712737
if (!ISD::isExtOpcode(Op1Opcode))
1273812738
return SDValue();
1273912739

12740-
SDValue UnextOp1 = Op1.getOperand(0);
12741-
EVT UnextOp1VT = UnextOp1.getValueType();
12742-
auto *Context = DAG.getContext();
12743-
if (!TLI.isPartialReduceMLALegalOrCustom(
12744-
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12745-
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
12746-
return SDValue();
12747-
1274812740
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
1274912741
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
1275012742
EVT AccElemVT = Acc.getValueType().getVectorElementType();
@@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1275412746

1275512747
unsigned NewOpcode =
1275612748
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12749+
12750+
SDValue UnextOp1 = Op1.getOperand(0);
12751+
EVT UnextOp1VT = UnextOp1.getValueType();
12752+
auto *Context = DAG.getContext();
12753+
if (!TLI.isPartialReduceMLALegalOrCustom(
12754+
NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
12755+
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
12756+
return SDValue();
12757+
1275712758
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
1275812759
DAG.getConstant(1, DL, UnextOp1VT));
1275912760
}

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
530530
}
531531
case ISD::PARTIAL_REDUCE_UMLA:
532532
case ISD::PARTIAL_REDUCE_SMLA:
533-
Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
534-
Node->getOperand(1).getValueType());
533+
Action =
534+
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
535+
Node->getOperand(1).getValueType());
535536
break;
536537

537538
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14581458
setOperationAction(ISD::FADD, VT, Custom);
14591459

14601460
if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
1461-
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
1462-
setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
1463-
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
1461+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1462+
ISD::PARTIAL_REDUCE_UMLA};
1463+
1464+
setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
1465+
setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
1466+
setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
14641467
}
14651468

14661469
} else /* !isNeonAvailable */ {
@@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18811884
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
18821885
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
18831886
// Other pairs will default to 'Expand'.
1884-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
1885-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
1887+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1888+
ISD::PARTIAL_REDUCE_UMLA};
1889+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal);
1890+
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal);
18861891

1887-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
1892+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
18881893

18891894
// Wide add types
18901895
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
1891-
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
1892-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
1893-
setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
1896+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
1897+
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
1898+
setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
18941899
}
18951900
}
18961901

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,11 +1575,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15751575

15761576
// zve32x is broken for partial_reduce_umla, but let's not make it worse.
15771577
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
1578-
setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
1579-
setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
1580-
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
1581-
setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
1582-
setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
1578+
static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1579+
ISD::PARTIAL_REDUCE_UMLA};
1580+
setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
1581+
setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
1582+
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
1583+
setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom);
1584+
setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom);
15831585

15841586
if (Subtarget.useRVVForFixedLengthVectors()) {
15851587
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
@@ -1588,7 +1590,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15881590
continue;
15891591
ElementCount EC = VT.getVectorElementCount();
15901592
MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
1591-
setPartialReduceMLAAction(VT, ArgVT, Custom);
1593+
setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom);
15921594
}
15931595
}
15941596
}

0 commit comments

Comments
 (0)