@@ -16936,18 +16936,9 @@ struct NodeExtensionHelper {
16936
16936
case RISCVISD::VWSUBU_W_VL:
16937
16937
case RISCVISD::VFWADD_W_VL:
16938
16938
case RISCVISD::VFWSUB_W_VL:
16939
- if (OperandIdx == 1) {
16940
- SupportsZExt =
16941
- Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
16942
- SupportsSExt =
16943
- Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
16944
- SupportsFPExt =
16945
- Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
16946
- // There's no existing extension here, so we don't have to worry about
16947
- // making sure it gets removed.
16948
- EnforceOneUse = false;
16939
+ // Operand 1 can't be changed.
16940
+ if (OperandIdx == 1)
16949
16941
break;
16950
- }
16951
16942
[[fallthrough]];
16952
16943
default:
16953
16944
fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -16985,20 +16976,20 @@ struct NodeExtensionHelper {
16985
16976
case RISCVISD::ADD_VL:
16986
16977
case RISCVISD::MUL_VL:
16987
16978
case RISCVISD::OR_VL:
16988
- case RISCVISD::VWADD_W_VL:
16989
- case RISCVISD::VWADDU_W_VL:
16990
16979
case RISCVISD::FADD_VL:
16991
16980
case RISCVISD::FMUL_VL:
16992
- case RISCVISD::VFWADD_W_VL:
16993
16981
case RISCVISD::VFMADD_VL:
16994
16982
case RISCVISD::VFNMSUB_VL:
16995
16983
case RISCVISD::VFNMADD_VL:
16996
16984
case RISCVISD::VFMSUB_VL:
16997
16985
return true;
16986
+ case RISCVISD::VWADD_W_VL:
16987
+ case RISCVISD::VWADDU_W_VL:
16998
16988
case ISD::SUB:
16999
16989
case RISCVISD::SUB_VL:
17000
16990
case RISCVISD::VWSUB_W_VL:
17001
16991
case RISCVISD::VWSUBU_W_VL:
16992
+ case RISCVISD::VFWADD_W_VL:
17002
16993
case RISCVISD::FSUB_VL:
17003
16994
case RISCVISD::VFWSUB_W_VL:
17004
16995
case ISD::SHL:
@@ -17117,6 +17108,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
17117
17108
Subtarget);
17118
17109
}
17119
17110
17111
+ /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17112
+ ///
17113
+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17114
+ /// can be used to apply the pattern.
17115
+ static std::optional<CombineResult>
17116
+ canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17117
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17118
+ const RISCVSubtarget &Subtarget) {
17119
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17120
+ Subtarget);
17121
+ }
17122
+
17123
+ /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17124
+ ///
17125
+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17126
+ /// can be used to apply the pattern.
17127
+ static std::optional<CombineResult>
17128
+ canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
17129
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17130
+ const RISCVSubtarget &Subtarget) {
17131
+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17132
+ Subtarget);
17133
+ }
17134
+
17120
17135
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
17121
17136
///
17122
17137
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17145,52 +17160,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
17145
17160
return std::nullopt;
17146
17161
}
17147
17162
17148
- /// Check if \p Root follows a pattern Root(sext(LHS), sext( RHS) )
17163
+ /// Check if \p Root follows a pattern Root(sext(LHS), RHS)
17149
17164
///
17150
17165
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17151
17166
/// can be used to apply the pattern.
17152
17167
static std::optional<CombineResult>
17153
17168
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17154
17169
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17155
17170
const RISCVSubtarget &Subtarget) {
17156
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
17157
- Subtarget);
17171
+ if (LHS.SupportsSExt)
17172
+ return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
17173
+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
17174
+ /*RHSExt=*/std::nullopt);
17175
+ return std::nullopt;
17158
17176
}
17159
17177
17160
- /// Check if \p Root follows a pattern Root(zext(LHS), zext( RHS) )
17178
+ /// Check if \p Root follows a pattern Root(zext(LHS), RHS)
17161
17179
///
17162
17180
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17163
17181
/// can be used to apply the pattern.
17164
17182
static std::optional<CombineResult>
17165
17183
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17166
17184
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17167
17185
const RISCVSubtarget &Subtarget) {
17168
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17169
- Subtarget);
17186
+ if (LHS.SupportsZExt)
17187
+ return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
17188
+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
17189
+ /*RHSExt=*/std::nullopt);
17190
+ return std::nullopt;
17170
17191
}
17171
17192
17172
- /// Check if \p Root follows a pattern Root(fpext(LHS), fpext( RHS) )
17193
+ /// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
17173
17194
///
17174
17195
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17175
17196
/// can be used to apply the pattern.
17176
17197
static std::optional<CombineResult>
17177
17198
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17178
17199
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17179
17200
const RISCVSubtarget &Subtarget) {
17180
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
17181
- Subtarget);
17182
- }
17183
-
17184
- /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17185
- ///
17186
- /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17187
- /// can be used to apply the pattern.
17188
- static std::optional<CombineResult>
17189
- canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17190
- const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17191
- const RISCVSubtarget &Subtarget) {
17192
- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17193
- Subtarget);
17201
+ if (LHS.SupportsFPExt)
17202
+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
17203
+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
17204
+ /*RHSExt=*/std::nullopt);
17205
+ return std::nullopt;
17194
17206
}
17195
17207
17196
17208
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17233,7 +17245,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
17233
17245
case RISCVISD::VFNMSUB_VL:
17234
17246
Strategies.push_back(canFoldToVWWithSameExtension);
17235
17247
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17236
- Strategies.push_back(canFoldToVWWithBF16EXT );
17248
+ Strategies.push_back(canFoldToVWWithSameExtBF16 );
17237
17249
break;
17238
17250
case ISD::MUL:
17239
17251
case RISCVISD::MUL_VL:
@@ -17245,7 +17257,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
17245
17257
case ISD::SHL:
17246
17258
case RISCVISD::SHL_VL:
17247
17259
// shl -> vwsll
17248
- Strategies.push_back(canFoldToVWWithZEXT );
17260
+ Strategies.push_back(canFoldToVWWithSameExtZEXT );
17249
17261
break;
17250
17262
case RISCVISD::VWADD_W_VL:
17251
17263
case RISCVISD::VWSUB_W_VL:
0 commit comments