Skip to content

Commit f480e1b

Browse files
authored
[NVPTX] Add PRMT constant folding and cleanup usage of PRMT node (llvm#148906)
1 parent 3b11aaa commit f480e1b

File tree

5 files changed

+2132
-954
lines changed

5 files changed

+2132
-954
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 184 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10481048
MVT::v32i32, MVT::v64i32, MVT::v128i32},
10491049
Custom);
10501050

1051-
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
1052-
// Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
1053-
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i128, Custom);
1051+
// Enable custom lowering for the following:
1052+
// * MVT::i128 - clusterlaunchcontrol
1053+
// * MVT::i32 - prmt
1054+
// * MVT::Other - internal.addrspace.wrap
1055+
setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
1056+
Custom);
10541057
}
10551058

10561059
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -2060,6 +2063,19 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20602063
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
20612064
}
20622065

2066+
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
2067+
SelectionDAG &DAG,
2068+
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2069+
return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
2070+
{A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
2071+
}
2072+
2073+
static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
2074+
SelectionDAG &DAG,
2075+
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2076+
return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
2077+
}
2078+
20632079
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
20642080
// Handle bitcasting from v2i8 without hitting the default promotion
20652081
// strategy which goes through stack memory.
@@ -2111,15 +2127,12 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
21112127
L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
21122128
R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
21132129
}
2114-
return DAG.getNode(
2115-
NVPTXISD::PRMT, DL, MVT::v4i8,
2116-
{L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
2117-
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2130+
return getPRMT(L, R, SelectionValue, DL, DAG);
21182131
};
21192132
auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
21202133
auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
21212134
auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
2122-
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
2135+
return DAG.getBitcast(VT, PRMT3210);
21232136
}
21242137

21252138
// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2189,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
21762189
SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
21772190
DAG.getZExtOrTrunc(Index, DL, MVT::i32),
21782191
DAG.getConstant(0x7770, DL, MVT::i32));
2179-
SDValue PRMT = DAG.getNode(
2180-
NVPTXISD::PRMT, DL, MVT::i32,
2181-
{DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32),
2182-
Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
2183-
return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
2192+
SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector),
2193+
DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);
2194+
SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
2195+
SDNodeFlags Flags;
2196+
Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
2197+
Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
2198+
Ext->setFlags(Flags);
2199+
return Ext;
21842200
}
21852201

21862202
// Constant index will be matched by tablegen.
@@ -2242,9 +2258,9 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
22422258
}
22432259

22442260
SDLoc DL(Op);
2245-
return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
2246-
DAG.getConstant(Selector, DL, MVT::i32),
2247-
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
2261+
SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, V1),
2262+
DAG.getBitcast(MVT::i32, V2), Selector, DL, DAG);
2263+
return DAG.getBitcast(Op.getValueType(), PRMT);
22482264
}
22492265
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
22502266
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2745,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
27292745
{TryCancelResponse0, TryCancelResponse1});
27302746
}
27312747

2748+
static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
2749+
const unsigned Mode = [&]() {
2750+
switch (Op->getConstantOperandVal(0)) {
2751+
case Intrinsic::nvvm_prmt:
2752+
return NVPTX::PTXPrmtMode::NONE;
2753+
case Intrinsic::nvvm_prmt_b4e:
2754+
return NVPTX::PTXPrmtMode::B4E;
2755+
case Intrinsic::nvvm_prmt_ecl:
2756+
return NVPTX::PTXPrmtMode::ECL;
2757+
case Intrinsic::nvvm_prmt_ecr:
2758+
return NVPTX::PTXPrmtMode::ECR;
2759+
case Intrinsic::nvvm_prmt_f4e:
2760+
return NVPTX::PTXPrmtMode::F4E;
2761+
case Intrinsic::nvvm_prmt_rc16:
2762+
return NVPTX::PTXPrmtMode::RC16;
2763+
case Intrinsic::nvvm_prmt_rc8:
2764+
return NVPTX::PTXPrmtMode::RC8;
2765+
default:
2766+
llvm_unreachable("unsupported/unhandled intrinsic");
2767+
}
2768+
}();
2769+
SDLoc DL(Op);
2770+
SDValue A = Op->getOperand(1);
2771+
SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(2)
2772+
: DAG.getConstant(0, DL, MVT::i32);
2773+
SDValue Selector = (Op->op_end() - 1)->get();
2774+
return getPRMT(A, B, Selector, DL, DAG, Mode);
2775+
}
27322776
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
27332777
switch (Op->getConstantOperandVal(0)) {
27342778
default:
27352779
return Op;
2780+
case Intrinsic::nvvm_prmt:
2781+
case Intrinsic::nvvm_prmt_b4e:
2782+
case Intrinsic::nvvm_prmt_ecl:
2783+
case Intrinsic::nvvm_prmt_ecr:
2784+
case Intrinsic::nvvm_prmt_f4e:
2785+
case Intrinsic::nvvm_prmt_rc16:
2786+
case Intrinsic::nvvm_prmt_rc8:
2787+
return lowerPrmtIntrinsic(Op, DAG);
27362788
case Intrinsic::nvvm_internal_addrspace_wrap:
27372789
return Op.getOperand(1);
27382790
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5827,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
57755827
SDLoc DL(N);
57765828
auto &DAG = DCI.DAG;
57775829

5778-
auto PRMT = DAG.getNode(
5779-
NVPTXISD::PRMT, DL, MVT::v4i8,
5780-
{Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
5781-
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
5782-
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
5830+
auto PRMT =
5831+
getPRMT(DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
5832+
(Op1Bytes << 8) | Op0Bytes, DL, DAG);
5833+
return DAG.getBitcast(VT, PRMT);
57835834
}
57845835

57855836
static SDValue combineADDRSPACECAST(SDNode *N,
@@ -5797,47 +5848,120 @@ static SDValue combineADDRSPACECAST(SDNode *N,
57975848
return SDValue();
57985849
}
57995850

5851+
// Given a constant selector value and a prmt mode, return the selector value
5852+
// normalized to the generic prmt mode. See the PTX ISA documentation for more
5853+
// details:
5854+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
5855+
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
5856+
if (Mode == NVPTX::PTXPrmtMode::NONE)
5857+
return Selector;
5858+
5859+
const unsigned V = Selector.trunc(2).getZExtValue();
5860+
5861+
const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
5862+
unsigned S3) {
5863+
return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
5864+
};
5865+
5866+
switch (Mode) {
5867+
case NVPTX::PTXPrmtMode::F4E:
5868+
return GetSelector(V, V + 1, V + 2, V + 3);
5869+
case NVPTX::PTXPrmtMode::B4E:
5870+
return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
5871+
case NVPTX::PTXPrmtMode::RC8:
5872+
return GetSelector(V, V, V, V);
5873+
case NVPTX::PTXPrmtMode::ECL:
5874+
return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);
5875+
case NVPTX::PTXPrmtMode::ECR:
5876+
return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);
5877+
case NVPTX::PTXPrmtMode::RC16: {
5878+
unsigned V1 = (V & 1) << 1;
5879+
return GetSelector(V1, V1 + 1, V1, V1 + 1);
5880+
}
5881+
default:
5882+
llvm_unreachable("Invalid PRMT mode");
5883+
}
5884+
}
5885+
5886+
static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
5887+
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
5888+
APInt BitField = B.concat(A);
5889+
APInt SelectorVal = getPRMTSelector(Selector, Mode);
5890+
APInt Result(32, 0);
5891+
for (unsigned I : llvm::seq(4U)) {
5892+
APInt Sel = SelectorVal.extractBits(4, I * 4);
5893+
unsigned Idx = Sel.getLoBits(3).getZExtValue();
5894+
unsigned Sign = Sel.getHiBits(1).getZExtValue();
5895+
APInt Byte = BitField.extractBits(8, Idx * 8);
5896+
if (Sign)
5897+
Byte = Byte.ashr(8);
5898+
Result.insertBits(Byte, I * 8);
5899+
}
5900+
return Result;
5901+
}
5902+
5903+
static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
5904+
CodeGenOptLevel OptLevel) {
5905+
if (OptLevel == CodeGenOptLevel::None)
5906+
return SDValue();
5907+
5908+
// Constant fold PRMT
5909+
if (isa<ConstantSDNode>(N->getOperand(0)) &&
5910+
isa<ConstantSDNode>(N->getOperand(1)) &&
5911+
isa<ConstantSDNode>(N->getOperand(2)))
5912+
return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0),
5913+
N->getConstantOperandAPInt(1),
5914+
N->getConstantOperandAPInt(2),
5915+
N->getConstantOperandVal(3)),
5916+
SDLoc(N), N->getValueType(0));
5917+
5918+
return SDValue();
5919+
}
5920+
58005921
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58015922
DAGCombinerInfo &DCI) const {
58025923
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
58035924
switch (N->getOpcode()) {
5804-
default: break;
5805-
case ISD::ADD:
5806-
return PerformADDCombine(N, DCI, OptLevel);
5807-
case ISD::FADD:
5808-
return PerformFADDCombine(N, DCI, OptLevel);
5809-
case ISD::MUL:
5810-
return PerformMULCombine(N, DCI, OptLevel);
5811-
case ISD::SHL:
5812-
return PerformSHLCombine(N, DCI, OptLevel);
5813-
case ISD::AND:
5814-
return PerformANDCombine(N, DCI);
5815-
case ISD::UREM:
5816-
case ISD::SREM:
5817-
return PerformREMCombine(N, DCI, OptLevel);
5818-
case ISD::SETCC:
5819-
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
5820-
case ISD::LOAD:
5821-
case NVPTXISD::LoadParamV2:
5822-
case NVPTXISD::LoadV2:
5823-
case NVPTXISD::LoadV4:
5824-
return combineUnpackingMovIntoLoad(N, DCI);
5825-
case NVPTXISD::StoreParam:
5826-
case NVPTXISD::StoreParamV2:
5827-
case NVPTXISD::StoreParamV4:
5828-
return PerformStoreParamCombine(N, DCI);
5829-
case ISD::STORE:
5830-
case NVPTXISD::StoreV2:
5831-
case NVPTXISD::StoreV4:
5832-
return PerformStoreCombine(N, DCI);
5833-
case ISD::EXTRACT_VECTOR_ELT:
5834-
return PerformEXTRACTCombine(N, DCI);
5835-
case ISD::VSELECT:
5836-
return PerformVSELECTCombine(N, DCI);
5837-
case ISD::BUILD_VECTOR:
5838-
return PerformBUILD_VECTORCombine(N, DCI);
5839-
case ISD::ADDRSPACECAST:
5840-
return combineADDRSPACECAST(N, DCI);
5925+
default:
5926+
break;
5927+
case ISD::ADD:
5928+
return PerformADDCombine(N, DCI, OptLevel);
5929+
case ISD::ADDRSPACECAST:
5930+
return combineADDRSPACECAST(N, DCI);
5931+
case ISD::AND:
5932+
return PerformANDCombine(N, DCI);
5933+
case ISD::BUILD_VECTOR:
5934+
return PerformBUILD_VECTORCombine(N, DCI);
5935+
case ISD::EXTRACT_VECTOR_ELT:
5936+
return PerformEXTRACTCombine(N, DCI);
5937+
case ISD::FADD:
5938+
return PerformFADDCombine(N, DCI, OptLevel);
5939+
case ISD::LOAD:
5940+
case NVPTXISD::LoadParamV2:
5941+
case NVPTXISD::LoadV2:
5942+
case NVPTXISD::LoadV4:
5943+
return combineUnpackingMovIntoLoad(N, DCI);
5944+
case ISD::MUL:
5945+
return PerformMULCombine(N, DCI, OptLevel);
5946+
case NVPTXISD::PRMT:
5947+
return combinePRMT(N, DCI, OptLevel);
5948+
case ISD::SETCC:
5949+
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
5950+
case ISD::SHL:
5951+
return PerformSHLCombine(N, DCI, OptLevel);
5952+
case ISD::SREM:
5953+
case ISD::UREM:
5954+
return PerformREMCombine(N, DCI, OptLevel);
5955+
case NVPTXISD::StoreParam:
5956+
case NVPTXISD::StoreParamV2:
5957+
case NVPTXISD::StoreParamV4:
5958+
return PerformStoreParamCombine(N, DCI);
5959+
case ISD::STORE:
5960+
case NVPTXISD::StoreV2:
5961+
case NVPTXISD::StoreV4:
5962+
return PerformStoreCombine(N, DCI);
5963+
case ISD::VSELECT:
5964+
return PerformVSELECTCombine(N, DCI);
58415965
}
58425966
return SDValue();
58435967
}
@@ -6387,7 +6511,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
63876511
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
63886512
unsigned Mode = Op.getConstantOperandVal(3);
63896513

6390-
if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
6514+
if (!Selector)
63916515
return;
63926516

63936517
KnownBits AKnown = DAG.computeKnownBits(A, Depth);
@@ -6396,7 +6520,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
63966520
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
63976521
KnownBits BitField = BKnown.concat(AKnown);
63986522

6399-
APInt SelectorVal = Selector->getAPIntValue();
6523+
APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
64006524
for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
64016525
APInt Sel = SelectorVal.extractBits(4, I * 4);
64026526
unsigned Idx = Sel.getLoBits(3).getZExtValue();

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,18 +1453,33 @@ let hasSideEffects = false in {
14531453
(ins PrmtMode:$mode),
14541454
"prmt.b32$mode",
14551455
[(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
1456+
def PRMT_B32rir
1457+
: BasicFlagsNVPTXInst<(outs B32:$d),
1458+
(ins B32:$a, i32imm:$b, B32:$c),
1459+
(ins PrmtMode:$mode),
1460+
"prmt.b32$mode",
1461+
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
14561462
def PRMT_B32rii
14571463
: BasicFlagsNVPTXInst<(outs B32:$d),
14581464
(ins B32:$a, i32imm:$b, Hexu32imm:$c),
14591465
(ins PrmtMode:$mode),
14601466
"prmt.b32$mode",
14611467
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
1462-
def PRMT_B32rir
1468+
def PRMT_B32irr
14631469
: BasicFlagsNVPTXInst<(outs B32:$d),
1464-
(ins B32:$a, i32imm:$b, B32:$c),
1465-
(ins PrmtMode:$mode),
1470+
(ins i32imm:$a, B32:$b, B32:$c), (ins PrmtMode:$mode),
1471+
"prmt.b32$mode",
1472+
[(set i32:$d, (prmt imm:$a, i32:$b, i32:$c, imm:$mode))]>;
1473+
def PRMT_B32iri
1474+
: BasicFlagsNVPTXInst<(outs B32:$d),
1475+
(ins i32imm:$a, B32:$b, Hexu32imm:$c), (ins PrmtMode:$mode),
1476+
"prmt.b32$mode",
1477+
[(set i32:$d, (prmt imm:$a, i32:$b, imm:$c, imm:$mode))]>;
1478+
def PRMT_B32iir
1479+
: BasicFlagsNVPTXInst<(outs B32:$d),
1480+
(ins i32imm:$a, i32imm:$b, B32:$c), (ins PrmtMode:$mode),
14661481
"prmt.b32$mode",
1467-
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
1482+
[(set i32:$d, (prmt imm:$a, imm:$b, i32:$c, imm:$mode))]>;
14681483

14691484
}
14701485

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,24 +1047,6 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
10471047
// MISC
10481048
//
10491049

1050-
class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1051-
: Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
1052-
(PRMT_B32rrr $a, $b, $c, prmt_mode)>;
1053-
1054-
class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1055-
: Pat<(prmt_intrinsic i32:$a, i32:$c),
1056-
(PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;
1057-
1058-
def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>;
1059-
def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>;
1060-
def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>;
1061-
1062-
def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>;
1063-
def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>;
1064-
def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>;
1065-
def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;
1066-
1067-
10681050
def INT_NVVM_NANOSLEEP_I : BasicNVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32",
10691051
[(int_nvvm_nanosleep imm:$i)]>,
10701052
Requires<[hasPTX<63>, hasSM<70>]>;

0 commit comments

Comments
 (0)