@@ -486,6 +486,7 @@ namespace {
486486 SDValue visitSIGN_EXTEND_INREG(SDNode *N);
487487 SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
488488 SDValue visitTRUNCATE(SDNode *N);
489+ SDValue visitTRUNCATE_USAT_U(SDNode *N);
489490 SDValue visitBITCAST(SDNode *N);
490491 SDValue visitFREEZE(SDNode *N);
491492 SDValue visitBUILD_PAIR(SDNode *N);
@@ -1910,6 +1911,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
19101911 case ISD::ZERO_EXTEND_VECTOR_INREG:
19111912 case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
19121913 case ISD::TRUNCATE: return visitTRUNCATE(N);
1914+ case ISD::TRUNCATE_USAT_U: return visitTRUNCATE_USAT_U(N);
19131915 case ISD::BITCAST: return visitBITCAST(N);
19141916 case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
19151917 case ISD::FADD: return visitFADD(N);
@@ -13198,7 +13200,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
1319813200 unsigned CastOpcode = Cast->getOpcode();
1319913201 assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
1320013202 CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13201- CastOpcode == ISD::FP_ROUND) &&
13203+ CastOpcode == ISD::TRUNCATE_SSAT_S ||
13204+ CastOpcode == ISD::TRUNCATE_SSAT_U ||
13205+ CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) &&
1320213206 "Unexpected opcode for vector select narrowing/widening");
1320313207
1320413208 // We only do this transform before legal ops because the pattern may be
@@ -14910,6 +14914,132 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
1491014914 return SDValue();
1491114915}
1491214916
14917+ SDValue DAGCombiner::visitTRUNCATE_USAT_U(SDNode *N) {
14918+ EVT VT = N->getValueType(0);
14919+ SDValue N0 = N->getOperand(0);
14920+
14921+ std::function<SDValue(SDValue)> MatchFPTOINT = [&](SDValue Val) -> SDValue {
14922+ if (Val.getOpcode() == ISD::FP_TO_UINT)
14923+ return Val;
14924+ return SDValue();
14925+ };
14926+
14927+ SDValue FPInstr = MatchFPTOINT(N0);
14928+ if (!FPInstr)
14929+ return SDValue();
14930+
14931+ EVT FPVT = FPInstr.getOperand(0).getValueType();
14932+ if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
14933+ FPVT, VT))
14934+ return SDValue();
14935+ return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
14936+ FPInstr.getOperand(0),
14937+ DAG.getValueType(VT.getScalarType()));
14938+ }
14939+
14940+ /// Detect patterns of truncation with unsigned saturation:
14941+ ///
14942+ /// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
14943+ /// Return the source value x to be truncated or SDValue() if the pattern was
14944+ /// not matched.
14945+ ///
14946+ static SDValue detectUSatUPattern(SDValue In, EVT VT) {
14947+ unsigned NumDstBits = VT.getScalarSizeInBits();
14948+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
14949+ // Saturation with truncation. We truncate from InVT to VT.
14950+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14951+
14952+ SDValue Min;
14953+ APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
14954+ if (sd_match(In, m_UMin(m_Value(Min), m_SpecificInt(UnsignedMax))))
14955+ return Min;
14956+
14957+ return SDValue();
14958+ }
14959+
14960+ /// Detect patterns of truncation with signed saturation:
14961+ /// (truncate (smin (smax (x, signed_min_of_dest_type),
14962+ /// signed_max_of_dest_type)) to dest_type)
14963+ /// or:
14964+ /// (truncate (smax (smin (x, signed_max_of_dest_type),
14965+ /// signed_min_of_dest_type)) to dest_type).
14966+ ///
14967+ /// Return the source value to be truncated or SDValue() if the pattern was not
14968+ /// matched.
14969+ static SDValue detectSSatSPattern(SDValue In, EVT VT) {
14970+ unsigned NumDstBits = VT.getScalarSizeInBits();
14971+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
14972+ // Saturation with truncation. We truncate from InVT to VT.
14973+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14974+
14975+ SDValue Val;
14976+ APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
14977+ APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
14978+
14979+ if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_SpecificInt(SignedMin)),
14980+ m_SpecificInt(SignedMax))))
14981+ return Val;
14982+
14983+ if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(SignedMax)),
14984+ m_SpecificInt(SignedMin))))
14985+ return Val;
14986+
14987+ return SDValue();
14988+ }
14989+
14990+ /// Detect patterns of truncation with unsigned saturation:
14991+ static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
14992+ const SDLoc &DL) {
14993+ unsigned NumDstBits = VT.getScalarSizeInBits();
14994+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
14995+ // Saturation with truncation. We truncate from InVT to VT.
14996+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14997+
14998+ SDValue Val;
14999+ APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
15000+ // Min == 0, Max is unsigned max of destination type.
15001+ if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(UnsignedMax)),
15002+ m_Zero())))
15003+ return Val;
15004+
15005+ if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_Zero()),
15006+ m_SpecificInt(UnsignedMax))))
15007+ return Val;
15008+
15009+ if (sd_match(In, m_UMin(m_SMax(m_Value(Val), m_Zero()),
15010+ m_SpecificInt(UnsignedMax))))
15011+ return Val;
15012+
15013+ return SDValue();
15014+ }
15015+
15016+ static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
15017+ SDLoc &DL, const TargetLowering &TLI,
15018+ SelectionDAG &DAG) {
15019+ auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
15020+ return (TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
15021+ TLI.isTypeDesirableForOp(Opc, VT));
15022+ };
15023+
15024+ if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
15025+ if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
15026+ if (SDValue SSatVal = detectSSatSPattern(Src, VT))
15027+ return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
15028+ if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15029+ if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15030+ return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15031+ } else if (Src.getOpcode() == ISD::UMIN) {
15032+ if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15033+ if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15034+ return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15035+ if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
15036+ if (SDValue USatVal = detectUSatUPattern(Src, VT))
15037+ return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
15038+ }
15039+
15040+ return SDValue();
15041+ }
15042+
1491315043SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1491415044 SDValue N0 = N->getOperand(0);
1491515045 EVT VT = N->getValueType(0);
@@ -14925,6 +15055,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1492515055 if (N0.getOpcode() == ISD::TRUNCATE)
1492615056 return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
1492715057
15058+ // fold saturated truncate
15059+ if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG))
15060+ return SaturatedTR;
15061+
1492815062 // fold (truncate c1) -> c1
1492915063 if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
1493015064 return C;
0 commit comments