Skip to content

Commit 58f0c21

Browse files
author
git apple-llvm automerger
committed
Merge commit '443cdd0b48b8' from llvm.org/main into next
2 parents: 0bc5065 + 443cdd0; commit: 58f0c21

File tree

3 files changed

+77
-15
lines changed

3 files changed

+77
-15
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8412,13 +8412,18 @@ SDValue RISCVTargetLowering::lowerPARTIAL_REDUCE_MLA(SDValue Op,
84128412
assert(ArgVT == B.getSimpleValueType() &&
84138413
ArgVT.getVectorElementType() == MVT::i8);
84148414

8415+
// The zvqdotq pseudos are defined with sources and destination both
8416+
// being i32. This cast is needed for correctness to avoid incorrect
8417+
// .vx matching of i8 splats.
8418+
A = DAG.getBitcast(VT, A);
8419+
B = DAG.getBitcast(VT, B);
8420+
84158421
MVT ContainerVT = VT;
84168422
if (VT.isFixedLengthVector()) {
84178423
ContainerVT = getContainerForFixedLengthVector(VT);
84188424
Accum = convertToScalableVector(ContainerVT, Accum, DAG, Subtarget);
8419-
MVT ArgContainerVT = getContainerForFixedLengthVector(ArgVT);
8420-
A = convertToScalableVector(ArgContainerVT, A, DAG, Subtarget);
8421-
B = convertToScalableVector(ArgContainerVT, B, DAG, Subtarget);
8425+
A = convertToScalableVector(ContainerVT, A, DAG, Subtarget);
8426+
B = convertToScalableVector(ContainerVT, B, DAG, Subtarget);
84228427
}
84238428

84248429
bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -598,7 +598,6 @@ entry:
598598
ret <1 x i32> %res
599599
}
600600

601-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
602601
define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
603602
; NODOT-LABEL: vqdotu_vx_partial_reduce:
604603
; NODOT: # %bb.0: # %entry
@@ -618,10 +617,13 @@ define <1 x i32> @vqdotu_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
618617
;
619618
; DOT-LABEL: vqdotu_vx_partial_reduce:
620619
; DOT: # %bb.0: # %entry
621-
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
620+
; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
622621
; DOT-NEXT: vmv.s.x v9, zero
623622
; DOT-NEXT: li a0, 128
624-
; DOT-NEXT: vqdotu.vx v9, v8, a0
623+
; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
624+
; DOT-NEXT: vmv.v.x v10, a0
625+
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
626+
; DOT-NEXT: vqdotu.vv v9, v8, v10
625627
; DOT-NEXT: vmv1r.v v8, v9
626628
; DOT-NEXT: ret
627629
entry:
@@ -631,7 +633,6 @@ entry:
631633
ret <1 x i32> %res
632634
}
633635

634-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
635636
define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
636637
; NODOT-LABEL: vqdot_vx_partial_reduce:
637638
; NODOT: # %bb.0: # %entry
@@ -652,10 +653,13 @@ define <1 x i32> @vqdot_vx_partial_reduce(<4 x i8> %a, <4 x i8> %b) {
652653
;
653654
; DOT-LABEL: vqdot_vx_partial_reduce:
654655
; DOT: # %bb.0: # %entry
655-
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
656+
; DOT-NEXT: vsetivli zero, 1, e32, m1, ta, ma
656657
; DOT-NEXT: vmv.s.x v9, zero
657658
; DOT-NEXT: li a0, 128
658-
; DOT-NEXT: vqdot.vx v9, v8, a0
659+
; DOT-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
660+
; DOT-NEXT: vmv.v.x v10, a0
661+
; DOT-NEXT: vsetivli zero, 1, e32, mf2, ta, ma
662+
; DOT-NEXT: vqdot.vv v9, v8, v10
659663
; DOT-NEXT: vmv1r.v v8, v9
660664
; DOT-NEXT: ret
661665
entry:
@@ -1372,7 +1376,6 @@ entry:
13721376
}
13731377

13741378

1375-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
13761379
define <4 x i32> @partial_of_sext(<16 x i8> %a) {
13771380
; NODOT-LABEL: partial_of_sext:
13781381
; NODOT: # %bb.0: # %entry
@@ -1393,10 +1396,11 @@ define <4 x i32> @partial_of_sext(<16 x i8> %a) {
13931396
;
13941397
; DOT-LABEL: partial_of_sext:
13951398
; DOT: # %bb.0: # %entry
1399+
; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
1400+
; DOT-NEXT: vmv.v.i v10, 1
13961401
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
13971402
; DOT-NEXT: vmv.v.i v9, 0
1398-
; DOT-NEXT: li a0, 1
1399-
; DOT-NEXT: vqdot.vx v9, v8, a0
1403+
; DOT-NEXT: vqdot.vv v9, v8, v10
14001404
; DOT-NEXT: vmv.v.v v8, v9
14011405
; DOT-NEXT: ret
14021406
entry:
@@ -1405,7 +1409,6 @@ entry:
14051409
ret <4 x i32> %res
14061410
}
14071411

1408-
; FIXME: This case is wrong. We should be splatting 128 to each i8 lane!
14091412
define <4 x i32> @partial_of_zext(<16 x i8> %a) {
14101413
; NODOT-LABEL: partial_of_zext:
14111414
; NODOT: # %bb.0: # %entry
@@ -1426,10 +1429,11 @@ define <4 x i32> @partial_of_zext(<16 x i8> %a) {
14261429
;
14271430
; DOT-LABEL: partial_of_zext:
14281431
; DOT: # %bb.0: # %entry
1432+
; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
1433+
; DOT-NEXT: vmv.v.i v10, 1
14291434
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
14301435
; DOT-NEXT: vmv.v.i v9, 0
1431-
; DOT-NEXT: li a0, 1
1432-
; DOT-NEXT: vqdotu.vx v9, v8, a0
1436+
; DOT-NEXT: vqdotu.vv v9, v8, v10
14331437
; DOT-NEXT: vmv.v.v v8, v9
14341438
; DOT-NEXT: ret
14351439
entry:

llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -957,3 +957,56 @@ entry:
957957
%res = call <vscale x 1 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 1 x i32> zeroinitializer, <vscale x 4 x i32> %mul)
958958
ret <vscale x 1 x i32> %res
959959
}
960+
961+
962+
define <vscale x 4 x i32> @partial_of_sext(<vscale x 16 x i8> %a) {
963+
; NODOT-LABEL: partial_of_sext:
964+
; NODOT: # %bb.0: # %entry
965+
; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
966+
; NODOT-NEXT: vsext.vf4 v16, v8
967+
; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
968+
; NODOT-NEXT: vadd.vv v8, v22, v16
969+
; NODOT-NEXT: vadd.vv v10, v18, v20
970+
; NODOT-NEXT: vadd.vv v8, v10, v8
971+
; NODOT-NEXT: ret
972+
;
973+
; DOT-LABEL: partial_of_sext:
974+
; DOT: # %bb.0: # %entry
975+
; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
976+
; DOT-NEXT: vmv.v.i v12, 1
977+
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
978+
; DOT-NEXT: vmv.v.i v10, 0
979+
; DOT-NEXT: vqdot.vv v10, v8, v12
980+
; DOT-NEXT: vmv.v.v v8, v10
981+
; DOT-NEXT: ret
982+
entry:
983+
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
984+
%res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
985+
ret <vscale x 4 x i32> %res
986+
}
987+
988+
define <vscale x 4 x i32> @partial_of_zext(<vscale x 16 x i8> %a) {
989+
; NODOT-LABEL: partial_of_zext:
990+
; NODOT: # %bb.0: # %entry
991+
; NODOT-NEXT: vsetvli a0, zero, e32, m8, ta, ma
992+
; NODOT-NEXT: vzext.vf4 v16, v8
993+
; NODOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
994+
; NODOT-NEXT: vadd.vv v8, v22, v16
995+
; NODOT-NEXT: vadd.vv v10, v18, v20
996+
; NODOT-NEXT: vadd.vv v8, v10, v8
997+
; NODOT-NEXT: ret
998+
;
999+
; DOT-LABEL: partial_of_zext:
1000+
; DOT: # %bb.0: # %entry
1001+
; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
1002+
; DOT-NEXT: vmv.v.i v12, 1
1003+
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
1004+
; DOT-NEXT: vmv.v.i v10, 0
1005+
; DOT-NEXT: vqdotu.vv v10, v8, v12
1006+
; DOT-NEXT: vmv.v.v v8, v10
1007+
; DOT-NEXT: ret
1008+
entry:
1009+
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
1010+
%res = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %a.ext)
1011+
ret <vscale x 4 x i32> %res
1012+
}

0 commit comments

Comments
 (0)