Skip to content

Commit 15a3183

Browse files
topperc authored and
tstellar committed
[RISCV] Re-work how VWADD_W_VL and similar _W_VL nodes are handled in combineOp_VLToVWOp_VL. (llvm#159205)
These instructions have one operand that is already narrow. Previously, we treated this operand as if it were a supported extension. This could cause problems when we called getOrCreateExtendedOp on this narrow operand while creating the VWADD_VL. If the narrow operand happened to be an extend of the opposite type, we would peek through it and then rebuild it with the wrong extension type. So (vwadd_w_vl (i32 (sext X)), (i16 (zext Y))) would become (vwadd_vl (i16 (sext X)), (i16 (sext Y))). To prevent this, we ignore the operand instead and pass std::nullopt for SupportsExt to getOrCreateExtendedOp so it won't peek through any extends on the narrow source. Fixes llvm#159152. (cherry picked from commit 6119d1f)
1 parent f089fb2 commit 15a3183

File tree

2 files changed

+72
-37
lines changed

2 files changed

+72
-37
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16936,18 +16936,9 @@ struct NodeExtensionHelper {
1693616936
case RISCVISD::VWSUBU_W_VL:
1693716937
case RISCVISD::VFWADD_W_VL:
1693816938
case RISCVISD::VFWSUB_W_VL:
16939-
if (OperandIdx == 1) {
16940-
SupportsZExt =
16941-
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
16942-
SupportsSExt =
16943-
Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
16944-
SupportsFPExt =
16945-
Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
16946-
// There's no existing extension here, so we don't have to worry about
16947-
// making sure it gets removed.
16948-
EnforceOneUse = false;
16939+
// Operand 1 can't be changed.
16940+
if (OperandIdx == 1)
1694916941
break;
16950-
}
1695116942
[[fallthrough]];
1695216943
default:
1695316944
fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -16985,20 +16976,20 @@ struct NodeExtensionHelper {
1698516976
case RISCVISD::ADD_VL:
1698616977
case RISCVISD::MUL_VL:
1698716978
case RISCVISD::OR_VL:
16988-
case RISCVISD::VWADD_W_VL:
16989-
case RISCVISD::VWADDU_W_VL:
1699016979
case RISCVISD::FADD_VL:
1699116980
case RISCVISD::FMUL_VL:
16992-
case RISCVISD::VFWADD_W_VL:
1699316981
case RISCVISD::VFMADD_VL:
1699416982
case RISCVISD::VFNMSUB_VL:
1699516983
case RISCVISD::VFNMADD_VL:
1699616984
case RISCVISD::VFMSUB_VL:
1699716985
return true;
16986+
case RISCVISD::VWADD_W_VL:
16987+
case RISCVISD::VWADDU_W_VL:
1699816988
case ISD::SUB:
1699916989
case RISCVISD::SUB_VL:
1700016990
case RISCVISD::VWSUB_W_VL:
1700116991
case RISCVISD::VWSUBU_W_VL:
16992+
case RISCVISD::VFWADD_W_VL:
1700216993
case RISCVISD::FSUB_VL:
1700316994
case RISCVISD::VFWSUB_W_VL:
1700416995
case ISD::SHL:
@@ -17117,6 +17108,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1711717108
Subtarget);
1711817109
}
1711917110

17111+
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17112+
///
17113+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17114+
/// can be used to apply the pattern.
17115+
static std::optional<CombineResult>
17116+
canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17117+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17118+
const RISCVSubtarget &Subtarget) {
17119+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17120+
Subtarget);
17121+
}
17122+
17123+
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17124+
///
17125+
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17126+
/// can be used to apply the pattern.
17127+
static std::optional<CombineResult>
17128+
canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
17129+
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17130+
const RISCVSubtarget &Subtarget) {
17131+
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17132+
Subtarget);
17133+
}
17134+
1712017135
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
1712117136
///
1712217137
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17145,52 +17160,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
1714517160
return std::nullopt;
1714617161
}
1714717162

17148-
/// Check if \p Root follows a pattern Root(sext(LHS), sext(RHS))
17163+
/// Check if \p Root follows a pattern Root(sext(LHS), RHS)
1714917164
///
1715017165
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1715117166
/// can be used to apply the pattern.
1715217167
static std::optional<CombineResult>
1715317168
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1715417169
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1715517170
const RISCVSubtarget &Subtarget) {
17156-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
17157-
Subtarget);
17171+
if (LHS.SupportsSExt)
17172+
return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
17173+
Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
17174+
/*RHSExt=*/std::nullopt);
17175+
return std::nullopt;
1715817176
}
1715917177

17160-
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17178+
/// Check if \p Root follows a pattern Root(zext(LHS), RHS)
1716117179
///
1716217180
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1716317181
/// can be used to apply the pattern.
1716417182
static std::optional<CombineResult>
1716517183
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1716617184
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1716717185
const RISCVSubtarget &Subtarget) {
17168-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17169-
Subtarget);
17186+
if (LHS.SupportsZExt)
17187+
return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
17188+
Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
17189+
/*RHSExt=*/std::nullopt);
17190+
return std::nullopt;
1717017191
}
1717117192

17172-
/// Check if \p Root follows a pattern Root(fpext(LHS), fpext(RHS))
17193+
/// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
1717317194
///
1717417195
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1717517196
/// can be used to apply the pattern.
1717617197
static std::optional<CombineResult>
1717717198
canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1717817199
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1717917200
const RISCVSubtarget &Subtarget) {
17180-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
17181-
Subtarget);
17182-
}
17183-
17184-
/// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17185-
///
17186-
/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17187-
/// can be used to apply the pattern.
17188-
static std::optional<CombineResult>
17189-
canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17190-
const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17191-
const RISCVSubtarget &Subtarget) {
17192-
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17193-
Subtarget);
17201+
if (LHS.SupportsFPExt)
17202+
return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
17203+
Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
17204+
/*RHSExt=*/std::nullopt);
17205+
return std::nullopt;
1719417206
}
1719517207

1719617208
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17233,7 +17245,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1723317245
case RISCVISD::VFNMSUB_VL:
1723417246
Strategies.push_back(canFoldToVWWithSameExtension);
1723517247
if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17236-
Strategies.push_back(canFoldToVWWithBF16EXT);
17248+
Strategies.push_back(canFoldToVWWithSameExtBF16);
1723717249
break;
1723817250
case ISD::MUL:
1723917251
case RISCVISD::MUL_VL:
@@ -17245,7 +17257,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1724517257
case ISD::SHL:
1724617258
case RISCVISD::SHL_VL:
1724717259
// shl -> vwsll
17248-
Strategies.push_back(canFoldToVWWithZEXT);
17260+
Strategies.push_back(canFoldToVWWithSameExtZEXT);
1724917261
break;
1725017262
case RISCVISD::VWADD_W_VL:
1725117263
case RISCVISD::VWSUB_W_VL:

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,26 @@ define <2 x i16> @vwmul_v2i16_multiple_users(ptr %x, ptr %y, ptr %z) {
5858
%i = or <2 x i16> %h, %g
5959
ret <2 x i16> %i
6060
}
61+
62+
; Make sure we have a vsext.vf2 and a vwaddu.vx.
63+
define <4 x i32> @pr159152(<4 x i8> %x) {
64+
; NO_FOLDING-LABEL: pr159152:
65+
; NO_FOLDING: # %bb.0:
66+
; NO_FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
67+
; NO_FOLDING-NEXT: vsext.vf2 v9, v8
68+
; NO_FOLDING-NEXT: li a0, 9
69+
; NO_FOLDING-NEXT: vwaddu.vx v8, v9, a0
70+
; NO_FOLDING-NEXT: ret
71+
;
72+
; FOLDING-LABEL: pr159152:
73+
; FOLDING: # %bb.0:
74+
; FOLDING-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
75+
; FOLDING-NEXT: vsext.vf2 v9, v8
76+
; FOLDING-NEXT: li a0, 9
77+
; FOLDING-NEXT: vwaddu.vx v8, v9, a0
78+
; FOLDING-NEXT: ret
79+
%a = sext <4 x i8> %x to <4 x i16>
80+
%b = zext <4 x i16> %a to <4 x i32>
81+
%c = add <4 x i32> %b, <i32 9, i32 9, i32 9, i32 9>
82+
ret <4 x i32> %c
83+
}

0 commit comments

Comments
 (0)