Skip to content

Commit a4b45c2

Browse files
committed
[RISCV] Allow fractional LMUL for reduction start value
For reductions, we need to put the start value into a source vector. For fractional LMULs, we can perform the operation at the original LMUL. For LMUL > 1, we eventually want to use a scalar insert, but that's outside the scope of this patch. Differential Revision: https://reviews.llvm.org/D139747
1 parent 81084bf commit a4b45c2

File tree

9 files changed

+245
-244
lines changed

9 files changed

+245
-244
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5814,9 +5814,16 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue,
58145814
const MVT M1VT = getLMUL1VT(VecVT);
58155815
const MVT XLenVT = Subtarget.getXLenVT();
58165816

5817+
// The reduction needs an LMUL1 input; do the splat at either LMUL1
5818+
// or the original VT if fractional.
5819+
auto InnerVT = VecVT.bitsLE(M1VT) ? VecVT : M1VT;
58175820
SDValue InitialSplat =
58185821
lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
5819-
M1VT, DL, DAG, Subtarget);
5822+
InnerVT, DL, DAG, Subtarget);
5823+
if (M1VT != InnerVT)
5824+
InitialSplat = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, M1VT,
5825+
DAG.getUNDEF(M1VT),
5826+
InitialSplat, DAG.getConstant(0, DL, XLenVT));
58205827
SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
58215828
SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
58225829
InitialSplat, Mask, VL);
@@ -8014,6 +8021,9 @@ static SDValue combineBinOpToReduce(SDNode *N, SelectionDAG &DAG) {
80148021
return SDValue();
80158022

80168023
SDValue ScalarV = Reduce.getOperand(2);
8024+
if (ScalarV.getOpcode() == ISD::INSERT_SUBVECTOR &&
8025+
ScalarV.getOperand(0)->isUndef())
8026+
ScalarV = ScalarV.getOperand(1);
80178027

80188028
// Make sure that ScalarV is a splat with VL=1.
80198029
if (ScalarV.getOpcode() != RISCVISD::VFMV_S_F_VL &&

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ declare half @llvm.vp.reduce.fadd.v2f16(half, <2 x half>, <2 x i1>, i32)
99
define half @vpreduce_fadd_v2f16(half %s, <2 x half> %v, <2 x i1> %m, i32 zeroext %evl) {
1010
; CHECK-LABEL: vpreduce_fadd_v2f16:
1111
; CHECK: # %bb.0:
12-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
12+
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
1313
; CHECK-NEXT: vfmv.s.f v9, fa0
1414
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, tu, ma
1515
; CHECK-NEXT: vfredusum.vs v9, v8, v9, v0.t
@@ -22,7 +22,7 @@ define half @vpreduce_fadd_v2f16(half %s, <2 x half> %v, <2 x i1> %m, i32 zeroex
2222
define half @vpreduce_ord_fadd_v2f16(half %s, <2 x half> %v, <2 x i1> %m, i32 zeroext %evl) {
2323
; CHECK-LABEL: vpreduce_ord_fadd_v2f16:
2424
; CHECK: # %bb.0:
25-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
25+
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
2626
; CHECK-NEXT: vfmv.s.f v9, fa0
2727
; CHECK-NEXT: vsetvli zero, a0, e16, mf4, tu, ma
2828
; CHECK-NEXT: vfredosum.vs v9, v8, v9, v0.t
@@ -37,7 +37,7 @@ declare half @llvm.vp.reduce.fadd.v4f16(half, <4 x half>, <4 x i1>, i32)
3737
define half @vpreduce_fadd_v4f16(half %s, <4 x half> %v, <4 x i1> %m, i32 zeroext %evl) {
3838
; CHECK-LABEL: vpreduce_fadd_v4f16:
3939
; CHECK: # %bb.0:
40-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
40+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
4141
; CHECK-NEXT: vfmv.s.f v9, fa0
4242
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, tu, ma
4343
; CHECK-NEXT: vfredusum.vs v9, v8, v9, v0.t
@@ -50,7 +50,7 @@ define half @vpreduce_fadd_v4f16(half %s, <4 x half> %v, <4 x i1> %m, i32 zeroex
5050
define half @vpreduce_ord_fadd_v4f16(half %s, <4 x half> %v, <4 x i1> %m, i32 zeroext %evl) {
5151
; CHECK-LABEL: vpreduce_ord_fadd_v4f16:
5252
; CHECK: # %bb.0:
53-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
53+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
5454
; CHECK-NEXT: vfmv.s.f v9, fa0
5555
; CHECK-NEXT: vsetvli zero, a0, e16, mf2, tu, ma
5656
; CHECK-NEXT: vfredosum.vs v9, v8, v9, v0.t
@@ -65,7 +65,7 @@ declare float @llvm.vp.reduce.fadd.v2f32(float, <2 x float>, <2 x i1>, i32)
6565
define float @vpreduce_fadd_v2f32(float %s, <2 x float> %v, <2 x i1> %m, i32 zeroext %evl) {
6666
; CHECK-LABEL: vpreduce_fadd_v2f32:
6767
; CHECK: # %bb.0:
68-
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
68+
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
6969
; CHECK-NEXT: vfmv.s.f v9, fa0
7070
; CHECK-NEXT: vsetvli zero, a0, e32, mf2, tu, ma
7171
; CHECK-NEXT: vfredusum.vs v9, v8, v9, v0.t
@@ -78,7 +78,7 @@ define float @vpreduce_fadd_v2f32(float %s, <2 x float> %v, <2 x i1> %m, i32 zer
7878
define float @vpreduce_ord_fadd_v2f32(float %s, <2 x float> %v, <2 x i1> %m, i32 zeroext %evl) {
7979
; CHECK-LABEL: vpreduce_ord_fadd_v2f32:
8080
; CHECK: # %bb.0:
81-
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
81+
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
8282
; CHECK-NEXT: vfmv.s.f v9, fa0
8383
; CHECK-NEXT: vsetvli zero, a0, e32, mf2, tu, ma
8484
; CHECK-NEXT: vfredosum.vs v9, v8, v9, v0.t

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp.ll

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -317,11 +317,10 @@ define float @vreduce_fwadd_v1f32(<1 x half>* %x, float %s) {
317317
define float @vreduce_ord_fwadd_v1f32(<1 x half>* %x, float %s) {
318318
; CHECK-LABEL: vreduce_ord_fwadd_v1f32:
319319
; CHECK: # %bb.0:
320-
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
320+
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
321321
; CHECK-NEXT: vle16.v v8, (a0)
322-
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
323322
; CHECK-NEXT: vfmv.s.f v9, fa0
324-
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
323+
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
325324
; CHECK-NEXT: vfwredosum.vs v8, v8, v9
326325
; CHECK-NEXT: vsetivli zero, 0, e32, m1, ta, ma
327326
; CHECK-NEXT: vfmv.f.s fa0, v8
@@ -365,11 +364,10 @@ define float @vreduce_ord_fadd_v2f32(<2 x float>* %x, float %s) {
365364
define float @vreduce_fwadd_v2f32(<2 x half>* %x, float %s) {
366365
; CHECK-LABEL: vreduce_fwadd_v2f32:
367366
; CHECK: # %bb.0:
368-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
367+
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
369368
; CHECK-NEXT: vle16.v v8, (a0)
370-
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
371369
; CHECK-NEXT: vfmv.s.f v9, fa0
372-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
370+
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
373371
; CHECK-NEXT: vfwredusum.vs v8, v8, v9
374372
; CHECK-NEXT: vsetivli zero, 0, e32, m1, ta, ma
375373
; CHECK-NEXT: vfmv.f.s fa0, v8
@@ -383,11 +381,10 @@ define float @vreduce_fwadd_v2f32(<2 x half>* %x, float %s) {
383381
define float @vreduce_ord_fwadd_v2f32(<2 x half>* %x, float %s) {
384382
; CHECK-LABEL: vreduce_ord_fwadd_v2f32:
385383
; CHECK: # %bb.0:
386-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
384+
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
387385
; CHECK-NEXT: vle16.v v8, (a0)
388-
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
389386
; CHECK-NEXT: vfmv.s.f v9, fa0
390-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
387+
; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
391388
; CHECK-NEXT: vfwredosum.vs v8, v8, v9
392389
; CHECK-NEXT: vsetivli zero, 0, e32, m1, ta, ma
393390
; CHECK-NEXT: vfmv.f.s fa0, v8
@@ -1185,7 +1182,7 @@ define half @vreduce_fmin_v2f16(<2 x half>* %x) {
11851182
; CHECK-NEXT: vle16.v v8, (a0)
11861183
; CHECK-NEXT: lui a0, %hi(.LCPI68_0)
11871184
; CHECK-NEXT: addi a0, a0, %lo(.LCPI68_0)
1188-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1185+
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
11891186
; CHECK-NEXT: vlse16.v v9, (a0), zero
11901187
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
11911188
; CHECK-NEXT: vfredmin.vs v8, v8, v9
@@ -1205,7 +1202,7 @@ define half @vreduce_fmin_v4f16(<4 x half>* %x) {
12051202
; CHECK-NEXT: vle16.v v8, (a0)
12061203
; CHECK-NEXT: lui a0, %hi(.LCPI69_0)
12071204
; CHECK-NEXT: addi a0, a0, %lo(.LCPI69_0)
1208-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1205+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
12091206
; CHECK-NEXT: vlse16.v v9, (a0), zero
12101207
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
12111208
; CHECK-NEXT: vfredmin.vs v8, v8, v9
@@ -1223,7 +1220,7 @@ define half @vreduce_fmin_v4f16_nonans(<4 x half>* %x) {
12231220
; CHECK-NEXT: vle16.v v8, (a0)
12241221
; CHECK-NEXT: lui a0, %hi(.LCPI70_0)
12251222
; CHECK-NEXT: addi a0, a0, %lo(.LCPI70_0)
1226-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1223+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
12271224
; CHECK-NEXT: vlse16.v v9, (a0), zero
12281225
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
12291226
; CHECK-NEXT: vfredmin.vs v8, v8, v9
@@ -1241,7 +1238,7 @@ define half @vreduce_fmin_v4f16_nonans_noinfs(<4 x half>* %x) {
12411238
; CHECK-NEXT: vle16.v v8, (a0)
12421239
; CHECK-NEXT: lui a0, %hi(.LCPI71_0)
12431240
; CHECK-NEXT: addi a0, a0, %lo(.LCPI71_0)
1244-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1241+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
12451242
; CHECK-NEXT: vlse16.v v9, (a0), zero
12461243
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
12471244
; CHECK-NEXT: vfredmin.vs v8, v8, v9
@@ -1285,7 +1282,7 @@ define float @vreduce_fmin_v2f32(<2 x float>* %x) {
12851282
; CHECK-NEXT: vle32.v v8, (a0)
12861283
; CHECK-NEXT: lui a0, %hi(.LCPI73_0)
12871284
; CHECK-NEXT: addi a0, a0, %lo(.LCPI73_0)
1288-
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
1285+
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
12891286
; CHECK-NEXT: vlse32.v v9, (a0), zero
12901287
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
12911288
; CHECK-NEXT: vfredmin.vs v8, v8, v9
@@ -1490,7 +1487,7 @@ define half @vreduce_fmax_v2f16(<2 x half>* %x) {
14901487
; CHECK-NEXT: vle16.v v8, (a0)
14911488
; CHECK-NEXT: lui a0, %hi(.LCPI83_0)
14921489
; CHECK-NEXT: addi a0, a0, %lo(.LCPI83_0)
1493-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1490+
; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
14941491
; CHECK-NEXT: vlse16.v v9, (a0), zero
14951492
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
14961493
; CHECK-NEXT: vfredmax.vs v8, v8, v9
@@ -1510,7 +1507,7 @@ define half @vreduce_fmax_v4f16(<4 x half>* %x) {
15101507
; CHECK-NEXT: vle16.v v8, (a0)
15111508
; CHECK-NEXT: lui a0, %hi(.LCPI84_0)
15121509
; CHECK-NEXT: addi a0, a0, %lo(.LCPI84_0)
1513-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1510+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
15141511
; CHECK-NEXT: vlse16.v v9, (a0), zero
15151512
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
15161513
; CHECK-NEXT: vfredmax.vs v8, v8, v9
@@ -1528,7 +1525,7 @@ define half @vreduce_fmax_v4f16_nonans(<4 x half>* %x) {
15281525
; CHECK-NEXT: vle16.v v8, (a0)
15291526
; CHECK-NEXT: lui a0, %hi(.LCPI85_0)
15301527
; CHECK-NEXT: addi a0, a0, %lo(.LCPI85_0)
1531-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1528+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
15321529
; CHECK-NEXT: vlse16.v v9, (a0), zero
15331530
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
15341531
; CHECK-NEXT: vfredmax.vs v8, v8, v9
@@ -1546,7 +1543,7 @@ define half @vreduce_fmax_v4f16_nonans_noinfs(<4 x half>* %x) {
15461543
; CHECK-NEXT: vle16.v v8, (a0)
15471544
; CHECK-NEXT: lui a0, %hi(.LCPI86_0)
15481545
; CHECK-NEXT: addi a0, a0, %lo(.LCPI86_0)
1549-
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1546+
; CHECK-NEXT: vsetivli zero, 1, e16, mf2, ta, ma
15501547
; CHECK-NEXT: vlse16.v v9, (a0), zero
15511548
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
15521549
; CHECK-NEXT: vfredmax.vs v8, v8, v9
@@ -1590,7 +1587,7 @@ define float @vreduce_fmax_v2f32(<2 x float>* %x) {
15901587
; CHECK-NEXT: vle32.v v8, (a0)
15911588
; CHECK-NEXT: lui a0, %hi(.LCPI88_0)
15921589
; CHECK-NEXT: addi a0, a0, %lo(.LCPI88_0)
1593-
; CHECK-NEXT: vsetivli zero, 1, e32, m1, ta, ma
1590+
; CHECK-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
15941591
; CHECK-NEXT: vlse32.v v9, (a0), zero
15951592
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
15961593
; CHECK-NEXT: vfredmax.vs v8, v8, v9

0 commit comments

Comments
 (0)