@@ -5217,103 +5217,131 @@ bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
52175217 return F.getFnAttribute (" unsafe-fp-math" ).getValueAsBool ();
52185218}
52195219
5220+ static bool isConstZero (const SDValue &Operand) {
5221+ const auto *Const = dyn_cast<ConstantSDNode>(Operand);
5222+ return Const && Const->getZExtValue () == 0 ;
5223+ }
5224+
52205225// / PerformADDCombineWithOperands - Try DAG combinations for an ADD with
52215226// / operands N0 and N1. This is a helper for PerformADDCombine that is
52225227// / called with the default operands, and if that fails, with commuted
52235228// / operands.
5224- static SDValue PerformADDCombineWithOperands (
5225- SDNode *N, SDValue N0, SDValue N1, TargetLowering::DAGCombinerInfo &DCI,
5226- const NVPTXSubtarget &Subtarget, CodeGenOptLevel OptLevel) {
5227- SelectionDAG &DAG = DCI.DAG ;
5228- // Skip non-integer, non-scalar case
5229- EVT VT=N0.getValueType ();
5230- if (VT.isVector ())
5229+ static SDValue
5230+ PerformADDCombineWithOperands (SDNode *N, SDValue N0, SDValue N1,
5231+ TargetLowering::DAGCombinerInfo &DCI) {
5232+ EVT VT = N0.getValueType ();
5233+
5234+ // Since integer multiply-add costs the same as integer multiply
5235+ // but is more costly than integer add, do the fusion only when
5236+ // the mul is only used in the add.
5237+ // TODO: this may not be true for later architectures, consider relaxing this
5238+ if (!N0.getNode ()->hasOneUse ())
52315239 return SDValue ();
52325240
52335241 // fold (add (mul a, b), c) -> (mad a, b, c)
52345242 //
5235- if (N0.getOpcode () == ISD::MUL) {
5236- assert (VT.isInteger ());
5237- // For integer:
5238- // Since integer multiply-add costs the same as integer multiply
5239- // but is more costly than integer add, do the fusion only when
5240- // the mul is only used in the add.
5241- if (OptLevel == CodeGenOptLevel::None || VT != MVT::i32 ||
5242- !N0.getNode ()->hasOneUse ())
5243+ if (N0.getOpcode () == ISD::MUL)
5244+ return DCI.DAG .getNode (NVPTXISD::IMAD, SDLoc (N), VT, N0.getOperand (0 ),
5245+ N0.getOperand (1 ), N1);
5246+
5247+ // fold (add (select cond, 0, (mul a, b)), c)
5248+ // -> (select cond, c, (mad a, b, c))
5249+ //
5250+ if (N0.getOpcode () == ISD::SELECT) {
5251+ unsigned ZeroOpNum;
5252+ if (isConstZero (N0->getOperand (1 )))
5253+ ZeroOpNum = 1 ;
5254+ else if (isConstZero (N0->getOperand (2 )))
5255+ ZeroOpNum = 2 ;
5256+ else
5257+ return SDValue ();
5258+
5259+ SDValue M = N0->getOperand ((ZeroOpNum == 1 ) ? 2 : 1 );
5260+ if (M->getOpcode () != ISD::MUL || !M.getNode ()->hasOneUse ())
52435261 return SDValue ();
52445262
5245- // Do the folding
5246- return DAG.getNode (NVPTXISD::IMAD, SDLoc (N), VT,
5247- N0.getOperand (0 ), N0.getOperand (1 ), N1);
5263+ SDValue MAD = DCI.DAG .getNode (NVPTXISD::IMAD, SDLoc (N), VT,
5264+ M->getOperand (0 ), M->getOperand (1 ), N1);
5265+ return DCI.DAG .getSelect (SDLoc (N), VT, N0->getOperand (0 ),
5266+ ((ZeroOpNum == 1 ) ? N1 : MAD),
5267+ ((ZeroOpNum == 1 ) ? MAD : N1));
52485268 }
5249- else if (N0.getOpcode () == ISD::FMUL) {
5250- if (VT == MVT::f32 || VT == MVT::f64 ) {
5251- const auto *TLI = static_cast <const NVPTXTargetLowering *>(
5252- &DAG.getTargetLoweringInfo ());
5253- if (!TLI->allowFMA (DAG.getMachineFunction (), OptLevel))
5254- return SDValue ();
52555269
5256- // For floating point:
5257- // Do the fusion only when the mul has less than 5 uses and all
5258- // are add.
5259- // The heuristic is that if a use is not an add, then that use
5260- // cannot be fused into fma, therefore mul is still needed anyway.
5261- // If there are more than 4 uses, even if they are all add, fusing
5262- // them will increase register pressue.
5263- //
5264- int numUses = 0 ;
5265- int nonAddCount = 0 ;
5266- for (const SDNode *User : N0.getNode ()->uses ()) {
5267- numUses++;
5268- if (User->getOpcode () != ISD::FADD)
5269- ++nonAddCount;
5270- }
5270+ return SDValue ();
5271+ }
5272+
5273+ static SDValue
5274+ PerformFADDCombineWithOperands (SDNode *N, SDValue N0, SDValue N1,
5275+ TargetLowering::DAGCombinerInfo &DCI,
5276+ CodeGenOptLevel OptLevel) {
5277+ EVT VT = N0.getValueType ();
5278+ if (N0.getOpcode () == ISD::FMUL) {
5279+ const auto *TLI = static_cast <const NVPTXTargetLowering *>(
5280+ &DCI.DAG .getTargetLoweringInfo ());
5281+ if (!TLI->allowFMA (DCI.DAG .getMachineFunction (), OptLevel))
5282+ return SDValue ();
5283+
5284+ // For floating point:
5285+ // Do the fusion only when the mul has less than 5 uses and all
5286+ // are add.
5287+ // The heuristic is that if a use is not an add, then that use
5288+ // cannot be fused into fma, therefore mul is still needed anyway.
5289+ // If there are more than 4 uses, even if they are all add, fusing
5290+ // them will increase register pressue.
5291+ //
5292+ int numUses = 0 ;
5293+ int nonAddCount = 0 ;
5294+ for (const SDNode *User : N0.getNode ()->uses ()) {
5295+ numUses++;
5296+ if (User->getOpcode () != ISD::FADD)
5297+ ++nonAddCount;
52715298 if (numUses >= 5 )
52725299 return SDValue ();
5273- if (nonAddCount) {
5274- int orderNo = N->getIROrder ();
5275- int orderNo2 = N0.getNode ()->getIROrder ();
5276- // simple heuristics here for considering potential register
5277- // pressure, the logics here is that the differnce are used
5278- // to measure the distance between def and use, the longer distance
5279- // more likely cause register pressure.
5280- if (orderNo - orderNo2 < 500 )
5281- return SDValue ();
5282-
5283- // Now, check if at least one of the FMUL's operands is live beyond the node N,
5284- // which guarantees that the FMA will not increase register pressure at node N.
5285- bool opIsLive = false ;
5286- const SDNode *left = N0.getOperand (0 ).getNode ();
5287- const SDNode *right = N0.getOperand (1 ).getNode ();
5288-
5289- if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5290- opIsLive = true ;
5291-
5292- if (!opIsLive)
5293- for (const SDNode *User : left->uses ()) {
5294- int orderNo3 = User->getIROrder ();
5295- if (orderNo3 > orderNo) {
5296- opIsLive = true ;
5297- break ;
5298- }
5299- }
5300+ }
5301+ if (nonAddCount) {
5302+ int orderNo = N->getIROrder ();
5303+ int orderNo2 = N0.getNode ()->getIROrder ();
5304+ // simple heuristics here for considering potential register
5305+ // pressure, the logics here is that the differnce are used
5306+ // to measure the distance between def and use, the longer distance
5307+ // more likely cause register pressure.
5308+ if (orderNo - orderNo2 < 500 )
5309+ return SDValue ();
53005310
5301- if (!opIsLive)
5302- for (const SDNode *User : right->uses ()) {
5303- int orderNo3 = User->getIROrder ();
5304- if (orderNo3 > orderNo) {
5305- opIsLive = true ;
5306- break ;
5307- }
5311+ // Now, check if at least one of the FMUL's operands is live beyond the
5312+ // node N, which guarantees that the FMA will not increase register
5313+ // pressure at node N.
5314+ bool opIsLive = false ;
5315+ const SDNode *left = N0.getOperand (0 ).getNode ();
5316+ const SDNode *right = N0.getOperand (1 ).getNode ();
5317+
5318+ if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
5319+ opIsLive = true ;
5320+
5321+ if (!opIsLive)
5322+ for (const SDNode *User : left->uses ()) {
5323+ int orderNo3 = User->getIROrder ();
5324+ if (orderNo3 > orderNo) {
5325+ opIsLive = true ;
5326+ break ;
53085327 }
5328+ }
53095329
5310- if (!opIsLive)
5311- return SDValue ();
5312- }
5330+ if (!opIsLive)
5331+ for (const SDNode *User : right->uses ()) {
5332+ int orderNo3 = User->getIROrder ();
5333+ if (orderNo3 > orderNo) {
5334+ opIsLive = true ;
5335+ break ;
5336+ }
5337+ }
53135338
5314- return DAG. getNode (ISD::FMA, SDLoc (N), VT,
5315- N0. getOperand ( 0 ), N0. getOperand ( 1 ), N1 );
5339+ if (!opIsLive)
5340+ return SDValue ( );
53165341 }
5342+
5343+ return DCI.DAG .getNode (ISD::FMA, SDLoc (N), VT, N0.getOperand (0 ),
5344+ N0.getOperand (1 ), N1);
53175345 }
53185346
53195347 return SDValue ();
@@ -5334,18 +5362,44 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
53345362// /
53355363static SDValue PerformADDCombine (SDNode *N,
53365364 TargetLowering::DAGCombinerInfo &DCI,
5337- const NVPTXSubtarget &Subtarget,
5365+ CodeGenOptLevel OptLevel) {
5366+ if (OptLevel == CodeGenOptLevel::None)
5367+ return SDValue ();
5368+
5369+ SDValue N0 = N->getOperand (0 );
5370+ SDValue N1 = N->getOperand (1 );
5371+
5372+ // Skip non-integer, non-scalar case
5373+ EVT VT = N0.getValueType ();
5374+ if (VT.isVector () || VT != MVT::i32 )
5375+ return SDValue ();
5376+
5377+ // First try with the default operand order.
5378+ if (SDValue Result = PerformADDCombineWithOperands (N, N0, N1, DCI))
5379+ return Result;
5380+
5381+ // If that didn't work, try again with the operands commuted.
5382+ return PerformADDCombineWithOperands (N, N1, N0, DCI);
5383+ }
5384+
5385+ // / PerformFADDCombine - Target-specific dag combine xforms for ISD::FADD.
5386+ // /
5387+ static SDValue PerformFADDCombine (SDNode *N,
5388+ TargetLowering::DAGCombinerInfo &DCI,
53385389 CodeGenOptLevel OptLevel) {
53395390 SDValue N0 = N->getOperand (0 );
53405391 SDValue N1 = N->getOperand (1 );
53415392
5393+ EVT VT = N0.getValueType ();
5394+ if (VT.isVector () || !(VT == MVT::f32 || VT == MVT::f64 ))
5395+ return SDValue ();
5396+
53425397 // First try with the default operand order.
5343- if (SDValue Result =
5344- PerformADDCombineWithOperands (N, N0, N1, DCI, Subtarget, OptLevel))
5398+ if (SDValue Result = PerformFADDCombineWithOperands (N, N0, N1, DCI, OptLevel))
53455399 return Result;
53465400
53475401 // If that didn't work, try again with the operands commuted.
5348- return PerformADDCombineWithOperands (N, N1, N0, DCI, Subtarget , OptLevel);
5402+ return PerformFADDCombineWithOperands (N, N1, N0, DCI, OptLevel);
53495403}
53505404
53515405static SDValue PerformANDCombine (SDNode *N,
@@ -5878,8 +5932,9 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58785932 switch (N->getOpcode ()) {
58795933 default : break ;
58805934 case ISD::ADD:
5935+ return PerformADDCombine (N, DCI, OptLevel);
58815936 case ISD::FADD:
5882- return PerformADDCombine (N, DCI, STI , OptLevel);
5937+ return PerformFADDCombine (N, DCI, OptLevel);
58835938 case ISD::MUL:
58845939 return PerformMULCombine (N, DCI, OptLevel);
58855940 case ISD::SHL:
0 commit comments