@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
1048
1048
MVT::v32i32, MVT::v64i32, MVT::v128i32},
1049
1049
Custom);
1050
1050
1051
- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
1052
- // Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
1053
- setOperationAction (ISD::INTRINSIC_WO_CHAIN, MVT::i128 , Custom);
1051
+ // Enable custom lowering for the following:
1052
+ // * MVT::i128 - clusterlaunchcontrol
1053
+ // * MVT::i32 - prmt
1054
+ // * MVT::Other - internal.addrspace.wrap
1055
+ setOperationAction (ISD::INTRINSIC_WO_CHAIN, {MVT::i32 , MVT::i128 , MVT::Other},
1056
+ Custom);
1054
1057
}
1055
1058
1056
1059
const char *NVPTXTargetLowering::getTargetNodeName (unsigned Opcode) const {
@@ -2060,6 +2063,19 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2060
2063
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2061
2064
}
2062
2065
2066
+ static SDValue getPRMT (SDValue A, SDValue B, SDValue Selector, SDLoc DL,
2067
+ SelectionDAG &DAG,
2068
+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2069
+ return DAG.getNode (NVPTXISD::PRMT, DL, MVT::i32 ,
2070
+ {A, B, Selector, DAG.getConstant (Mode, DL, MVT::i32 )});
2071
+ }
2072
+
2073
+ static SDValue getPRMT (SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
2074
+ SelectionDAG &DAG,
2075
+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2076
+ return getPRMT (A, B, DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG, Mode);
2077
+ }
2078
+
2063
2079
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
2064
2080
// Handle bitcasting from v2i8 without hitting the default promotion
2065
2081
// strategy which goes through stack memory.
@@ -2111,15 +2127,12 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
2111
2127
L = DAG.getAnyExtOrTrunc (L, DL, MVT::i32 );
2112
2128
R = DAG.getAnyExtOrTrunc (R, DL, MVT::i32 );
2113
2129
}
2114
- return DAG.getNode (
2115
- NVPTXISD::PRMT, DL, MVT::v4i8,
2116
- {L, R, DAG.getConstant (SelectionValue, DL, MVT::i32 ),
2117
- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2130
+ return getPRMT (L, R, SelectionValue, DL, DAG);
2118
2131
};
2119
2132
auto PRMT__10 = GetPRMT (Op->getOperand (0 ), Op->getOperand (1 ), true , 0x3340 );
2120
2133
auto PRMT__32 = GetPRMT (Op->getOperand (2 ), Op->getOperand (3 ), true , 0x3340 );
2121
2134
auto PRMT3210 = GetPRMT (PRMT__10, PRMT__32, false , 0x5410 );
2122
- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT3210);
2135
+ return DAG.getBitcast ( VT, PRMT3210);
2123
2136
}
2124
2137
2125
2138
// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2189,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2176
2189
SDValue Selector = DAG.getNode (ISD::OR, DL, MVT::i32 ,
2177
2190
DAG.getZExtOrTrunc (Index, DL, MVT::i32 ),
2178
2191
DAG.getConstant (0x7770 , DL, MVT::i32 ));
2179
- SDValue PRMT = DAG.getNode (
2180
- NVPTXISD::PRMT, DL, MVT::i32 ,
2181
- {DAG.getBitcast (MVT::i32 , Vector), DAG.getConstant (0 , DL, MVT::i32 ),
2182
- Selector, DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
2183
- return DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2192
+ SDValue PRMT = getPRMT (DAG.getBitcast (MVT::i32 , Vector),
2193
+ DAG.getConstant (0 , DL, MVT::i32 ), Selector, DL, DAG);
2194
+ SDValue Ext = DAG.getAnyExtOrTrunc (PRMT, DL, Op->getValueType (0 ));
2195
+ SDNodeFlags Flags;
2196
+ Flags.setNoSignedWrap (Ext.getScalarValueSizeInBits () > 8 );
2197
+ Flags.setNoUnsignedWrap (Ext.getScalarValueSizeInBits () >= 8 );
2198
+ Ext->setFlags (Flags);
2199
+ return Ext;
2184
2200
}
2185
2201
2186
2202
// Constant index will be matched by tablegen.
@@ -2242,9 +2258,9 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
2242
2258
}
2243
2259
2244
2260
SDLoc DL (Op);
2245
- return DAG. getNode (NVPTXISD:: PRMT, DL, MVT::v4i8 , V1, V2 ,
2246
- DAG.getConstant ( Selector, DL, MVT:: i32 ),
2247
- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT:: i32 ) );
2261
+ SDValue PRMT = getPRMT (DAG. getBitcast ( MVT::i32 , V1) ,
2262
+ DAG.getBitcast (MVT:: i32 , V2), Selector, DL, DAG);
2263
+ return DAG.getBitcast (Op. getValueType (), PRMT );
2248
2264
}
2249
2265
// / LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
2250
2266
// / 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2745,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
2729
2745
{TryCancelResponse0, TryCancelResponse1});
2730
2746
}
2731
2747
2748
+ static SDValue lowerPrmtIntrinsic (SDValue Op, SelectionDAG &DAG) {
2749
+ const unsigned Mode = [&]() {
2750
+ switch (Op->getConstantOperandVal (0 )) {
2751
+ case Intrinsic::nvvm_prmt:
2752
+ return NVPTX::PTXPrmtMode::NONE;
2753
+ case Intrinsic::nvvm_prmt_b4e:
2754
+ return NVPTX::PTXPrmtMode::B4E;
2755
+ case Intrinsic::nvvm_prmt_ecl:
2756
+ return NVPTX::PTXPrmtMode::ECL;
2757
+ case Intrinsic::nvvm_prmt_ecr:
2758
+ return NVPTX::PTXPrmtMode::ECR;
2759
+ case Intrinsic::nvvm_prmt_f4e:
2760
+ return NVPTX::PTXPrmtMode::F4E;
2761
+ case Intrinsic::nvvm_prmt_rc16:
2762
+ return NVPTX::PTXPrmtMode::RC16;
2763
+ case Intrinsic::nvvm_prmt_rc8:
2764
+ return NVPTX::PTXPrmtMode::RC8;
2765
+ default :
2766
+ llvm_unreachable (" unsupported/unhandled intrinsic" );
2767
+ }
2768
+ }();
2769
+ SDLoc DL (Op);
2770
+ SDValue A = Op->getOperand (1 );
2771
+ SDValue B = Op.getNumOperands () == 4 ? Op.getOperand (2 )
2772
+ : DAG.getConstant (0 , DL, MVT::i32 );
2773
+ SDValue Selector = (Op->op_end () - 1 )->get ();
2774
+ return getPRMT (A, B, Selector, DL, DAG, Mode);
2775
+ }
2732
2776
static SDValue lowerIntrinsicWOChain (SDValue Op, SelectionDAG &DAG) {
2733
2777
switch (Op->getConstantOperandVal (0 )) {
2734
2778
default :
2735
2779
return Op;
2780
+ case Intrinsic::nvvm_prmt:
2781
+ case Intrinsic::nvvm_prmt_b4e:
2782
+ case Intrinsic::nvvm_prmt_ecl:
2783
+ case Intrinsic::nvvm_prmt_ecr:
2784
+ case Intrinsic::nvvm_prmt_f4e:
2785
+ case Intrinsic::nvvm_prmt_rc16:
2786
+ case Intrinsic::nvvm_prmt_rc8:
2787
+ return lowerPrmtIntrinsic (Op, DAG);
2736
2788
case Intrinsic::nvvm_internal_addrspace_wrap:
2737
2789
return Op.getOperand (1 );
2738
2790
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5827,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5775
5827
SDLoc DL (N);
5776
5828
auto &DAG = DCI.DAG ;
5777
5829
5778
- auto PRMT = DAG.getNode (
5779
- NVPTXISD::PRMT, DL, MVT::v4i8,
5780
- {Op0, Op1, DAG.getConstant ((Op1Bytes << 8 ) | Op0Bytes, DL, MVT::i32 ),
5781
- DAG.getConstant (NVPTX::PTXPrmtMode::NONE, DL, MVT::i32 )});
5782
- return DAG.getNode (ISD::BITCAST, DL, VT, PRMT);
5830
+ auto PRMT =
5831
+ getPRMT (DAG.getBitcast (MVT::i32 , Op0), DAG.getBitcast (MVT::i32 , Op1),
5832
+ (Op1Bytes << 8 ) | Op0Bytes, DL, DAG);
5833
+ return DAG.getBitcast (VT, PRMT);
5783
5834
}
5784
5835
5785
5836
static SDValue combineADDRSPACECAST (SDNode *N,
@@ -5797,47 +5848,120 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5797
5848
return SDValue ();
5798
5849
}
5799
5850
5851
+ // Given a constant selector value and a prmt mode, return the selector value
5852
+ // normalized to the generic prmt mode. See the PTX ISA documentation for more
5853
+ // details:
5854
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
5855
+ static APInt getPRMTSelector (const APInt &Selector, unsigned Mode) {
5856
+ if (Mode == NVPTX::PTXPrmtMode::NONE)
5857
+ return Selector;
5858
+
5859
+ const unsigned V = Selector.trunc (2 ).getZExtValue ();
5860
+
5861
+ const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
5862
+ unsigned S3) {
5863
+ return APInt (32 , S0 | (S1 << 4 ) | (S2 << 8 ) | (S3 << 12 ));
5864
+ };
5865
+
5866
+ switch (Mode) {
5867
+ case NVPTX::PTXPrmtMode::F4E:
5868
+ return GetSelector (V, V + 1 , V + 2 , V + 3 );
5869
+ case NVPTX::PTXPrmtMode::B4E:
5870
+ return GetSelector (V, (V - 1 ) & 7 , (V - 2 ) & 7 , (V - 3 ) & 7 );
5871
+ case NVPTX::PTXPrmtMode::RC8:
5872
+ return GetSelector (V, V, V, V);
5873
+ case NVPTX::PTXPrmtMode::ECL:
5874
+ return GetSelector (V, std::max (V, 1U ), std::max (V, 2U ), 3U );
5875
+ case NVPTX::PTXPrmtMode::ECR:
5876
+ return GetSelector (0 , std::min (V, 1U ), std::min (V, 2U ), V);
5877
+ case NVPTX::PTXPrmtMode::RC16: {
5878
+ unsigned V1 = (V & 1 ) << 1 ;
5879
+ return GetSelector (V1, V1 + 1 , V1, V1 + 1 );
5880
+ }
5881
+ default :
5882
+ llvm_unreachable (" Invalid PRMT mode" );
5883
+ }
5884
+ }
5885
+
5886
+ static APInt computePRMT (APInt A, APInt B, APInt Selector, unsigned Mode) {
5887
+ // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
5888
+ APInt BitField = B.concat (A);
5889
+ APInt SelectorVal = getPRMTSelector (Selector, Mode);
5890
+ APInt Result (32 , 0 );
5891
+ for (unsigned I : llvm::seq (4U )) {
5892
+ APInt Sel = SelectorVal.extractBits (4 , I * 4 );
5893
+ unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
5894
+ unsigned Sign = Sel.getHiBits (1 ).getZExtValue ();
5895
+ APInt Byte = BitField.extractBits (8 , Idx * 8 );
5896
+ if (Sign)
5897
+ Byte = Byte.ashr (8 );
5898
+ Result.insertBits (Byte, I * 8 );
5899
+ }
5900
+ return Result;
5901
+ }
5902
+
5903
+ static SDValue combinePRMT (SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
5904
+ CodeGenOptLevel OptLevel) {
5905
+ if (OptLevel == CodeGenOptLevel::None)
5906
+ return SDValue ();
5907
+
5908
+ // Constant fold PRMT
5909
+ if (isa<ConstantSDNode>(N->getOperand (0 )) &&
5910
+ isa<ConstantSDNode>(N->getOperand (1 )) &&
5911
+ isa<ConstantSDNode>(N->getOperand (2 )))
5912
+ return DCI.DAG .getConstant (computePRMT (N->getConstantOperandAPInt (0 ),
5913
+ N->getConstantOperandAPInt (1 ),
5914
+ N->getConstantOperandAPInt (2 ),
5915
+ N->getConstantOperandVal (3 )),
5916
+ SDLoc (N), N->getValueType (0 ));
5917
+
5918
+ return SDValue ();
5919
+ }
5920
+
5800
5921
// Entry point for target-specific DAG combines: switches on the opcode of N
// and forwards to the matching combine helper. Some helpers also receive the
// optimization level obtained from the target machine, and the SETCC combine
// receives the subtarget's SM version. Opcodes without a case fall through to
// the default, which breaks out of the switch so that an empty SDValue is
// returned. NVPTXISD::PRMT nodes are routed to combinePRMT.
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5801
5922
DAGCombinerInfo &DCI) const {
5802
5923
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
5803
5924
switch (N->getOpcode ()) {
5804
- default : break ;
5805
- case ISD::ADD:
5806
- return PerformADDCombine (N, DCI, OptLevel);
5807
- case ISD::FADD:
5808
- return PerformFADDCombine (N, DCI, OptLevel);
5809
- case ISD::MUL:
5810
- return PerformMULCombine (N, DCI, OptLevel);
5811
- case ISD::SHL:
5812
- return PerformSHLCombine (N, DCI, OptLevel);
5813
- case ISD::AND:
5814
- return PerformANDCombine (N, DCI);
5815
- case ISD::UREM:
5816
- case ISD::SREM:
5817
- return PerformREMCombine (N, DCI, OptLevel);
5818
- case ISD::SETCC:
5819
- return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5820
- case ISD::LOAD:
5821
- case NVPTXISD::LoadParamV2:
5822
- case NVPTXISD::LoadV2:
5823
- case NVPTXISD::LoadV4:
5824
- return combineUnpackingMovIntoLoad (N, DCI);
5825
- case NVPTXISD::StoreParam:
5826
- case NVPTXISD::StoreParamV2:
5827
- case NVPTXISD::StoreParamV4:
5828
- return PerformStoreParamCombine (N, DCI);
5829
- case ISD::STORE:
5830
- case NVPTXISD::StoreV2:
5831
- case NVPTXISD::StoreV4:
5832
- return PerformStoreCombine (N, DCI);
5833
- case ISD::EXTRACT_VECTOR_ELT:
5834
- return PerformEXTRACTCombine (N, DCI);
5835
- case ISD::VSELECT:
5836
- return PerformVSELECTCombine (N, DCI);
5837
- case ISD::BUILD_VECTOR:
5838
- return PerformBUILD_VECTORCombine (N, DCI);
5839
- case ISD::ADDRSPACECAST:
5840
- return combineADDRSPACECAST (N, DCI);
5925
+ default :
5926
+ break ;
5927
+ case ISD::ADD:
5928
+ return PerformADDCombine (N, DCI, OptLevel);
5929
+ case ISD::ADDRSPACECAST:
5930
+ return combineADDRSPACECAST (N, DCI);
5931
+ case ISD::AND:
5932
+ return PerformANDCombine (N, DCI);
5933
+ case ISD::BUILD_VECTOR:
5934
+ return PerformBUILD_VECTORCombine (N, DCI);
5935
+ case ISD::EXTRACT_VECTOR_ELT:
5936
+ return PerformEXTRACTCombine (N, DCI);
5937
+ case ISD::FADD:
5938
+ return PerformFADDCombine (N, DCI, OptLevel);
5939
+ case ISD::LOAD:
5940
+ case NVPTXISD::LoadParamV2:
5941
+ case NVPTXISD::LoadV2:
5942
+ case NVPTXISD::LoadV4:
5943
+ return combineUnpackingMovIntoLoad (N, DCI);
5944
+ case ISD::MUL:
5945
+ return PerformMULCombine (N, DCI, OptLevel);
5946
+ case NVPTXISD::PRMT:
5947
+ return combinePRMT (N, DCI, OptLevel);
5948
+ case ISD::SETCC:
5949
+ return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5950
+ case ISD::SHL:
5951
+ return PerformSHLCombine (N, DCI, OptLevel);
5952
+ case ISD::SREM:
5953
+ case ISD::UREM:
5954
+ return PerformREMCombine (N, DCI, OptLevel);
5955
+ case NVPTXISD::StoreParam:
5956
+ case NVPTXISD::StoreParamV2:
5957
+ case NVPTXISD::StoreParamV4:
5958
+ return PerformStoreParamCombine (N, DCI);
5959
+ case ISD::STORE:
5960
+ case NVPTXISD::StoreV2:
5961
+ case NVPTXISD::StoreV4:
5962
+ return PerformStoreCombine (N, DCI);
5963
+ case ISD::VSELECT:
5964
+ return PerformVSELECTCombine (N, DCI);
5841
5965
}
5842
5966
return SDValue ();
5843
5967
}
@@ -6387,7 +6511,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6387
6511
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand (2 ));
6388
6512
unsigned Mode = Op.getConstantOperandVal (3 );
6389
6513
6390
- if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
6514
+ if (!Selector)
6391
6515
return ;
6392
6516
6393
6517
KnownBits AKnown = DAG.computeKnownBits (A, Depth);
@@ -6396,7 +6520,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6396
6520
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6397
6521
KnownBits BitField = BKnown.concat (AKnown);
6398
6522
6399
- APInt SelectorVal = Selector->getAPIntValue ();
6523
+ APInt SelectorVal = getPRMTSelector ( Selector->getAPIntValue (), Mode );
6400
6524
for (unsigned I : llvm::seq (std::min (4U , Known.getBitWidth () / 8 ))) {
6401
6525
APInt Sel = SelectorVal.extractBits (4 , I * 4 );
6402
6526
unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
0 commit comments