Skip to content

Commit a4dc257

Browse files
author
git apple-llvm automerger
committed
Merge commit '783b050f8826' from llvm.org/main into next
2 parents 65d5eaa + 783b050 commit a4dc257

File tree

3 files changed

+227
-7
lines changed

3 files changed

+227
-7
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
9797
static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
9898
cl::init(false));
9999

100+
static cl::opt<unsigned> SplitMatmulRemainderOverThreshold(
101+
"matrix-split-matmul-remainder-over-threshold", cl::Hidden,
102+
cl::desc("Illegal remainder vectors over this size in bits should be split "
103+
"in the inner loop of matmul"),
104+
cl::init(0));
105+
100106
/// Helper function to either return Scope, if it is a subprogram or the
101107
/// attached subprogram for a local scope.
102108
static DISubprogram *getSubprogram(DIScope *Scope) {
@@ -1720,6 +1726,31 @@ class LowerMatrixIntrinsics {
17201726
ToRemove.push_back(MatMul);
17211727
}
17221728

1729+
/// Given \p Remainder iterations of the the matmul inner loop,
1730+
/// potentially lower \p Blocksize that is used for the underlying
1731+
/// vector.
1732+
unsigned capBlockSize(unsigned BlockSize, unsigned Remainder, Type *EltType) {
1733+
if (BlockSize <= Remainder)
1734+
return BlockSize;
1735+
1736+
// If the remainder is also a legal type just use it.
1737+
auto *VecTy = FixedVectorType::get(EltType, Remainder);
1738+
if (TTI.isTypeLegal(VecTy))
1739+
return Remainder;
1740+
1741+
// Similarly, if the vector is small enough that we don't want
1742+
// to split further.
1743+
if (VecTy->getPrimitiveSizeInBits() <= SplitMatmulRemainderOverThreshold)
1744+
return Remainder;
1745+
1746+
// Gradually lower the vectorization factor to cover the
1747+
// remainder.
1748+
do {
1749+
BlockSize /= 2;
1750+
} while (BlockSize > Remainder);
1751+
return BlockSize;
1752+
}
1753+
17231754
/// Compute \p Result += \p A * \p B for input matrices with left-associating
17241755
/// addition.
17251756
///
@@ -1757,10 +1788,8 @@ class LowerMatrixIntrinsics {
17571788
bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
17581789

17591790
for (unsigned I = 0; I < R; I += BlockSize) {
1760-
// Gradually lower the vectorization factor to cover the remainder.
1761-
while (I + BlockSize > R)
1762-
BlockSize /= 2;
1763-
1791+
// Lower block size to make sure we stay within bounds.
1792+
BlockSize = capBlockSize(BlockSize, R - I, Result.getElementType());
17641793
Value *Sum = IsTiled ? Result.extractVector(I, J, BlockSize, Builder)
17651794
: nullptr;
17661795
for (unsigned K = 0; K < M; ++K) {
@@ -1785,9 +1814,8 @@ class LowerMatrixIntrinsics {
17851814
unsigned BlockSize = VF;
17861815
bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
17871816
for (unsigned J = 0; J < C; J += BlockSize) {
1788-
// Gradually lower the vectorization factor to cover the remainder.
1789-
while (J + BlockSize > C)
1790-
BlockSize /= 2;
1817+
// Lower the vectorization factor to cover the remainder.
1818+
BlockSize = capBlockSize(BlockSize, C - J, Result.getElementType());
17911819

17921820
Value *Sum = nullptr;
17931821
for (unsigned K = 0; K < M; ++K) {
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=SPLIT_REMAINDER %s
3+
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder-over-threshold=96 -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=NO_SPLIT_REMAINDER %s
4+
; RUN: opt -passes='lower-matrix-intrinsics' -matrix-split-matmul-remainder-over-threshold=64 -matrix-default-layout=row-major -S < %s | FileCheck --check-prefix=SPLIT_REMAINDER %s
5+
6+
; REQUIRES: aarch64-registered-target
7+
8+
target datalayout = "e-m:o-i64:64-f80:128-n8:8:32:64-S128"
9+
target triple = "aarch64-apple-ios"
10+
11+
define void @matmul(ptr %a, ptr %b, ptr %c) {
12+
; SPLIT_REMAINDER-LABEL: define void @matmul(
13+
; SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
14+
; SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
15+
; SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
16+
; SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
17+
; SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
18+
; SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
19+
; SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
20+
; SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
21+
; SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
22+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <2 x float> poison, float [[TMP1]], i64 0
23+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT]], <2 x float> poison, <2 x i32> zeroinitializer
24+
; SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <2 x float> [[SPLAT_SPLAT]], [[BLOCK]]
25+
; SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
26+
; SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
27+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <2 x float> poison, float [[TMP3]], i64 0
28+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT6]], <2 x float> poison, <2 x i32> zeroinitializer
29+
; SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <2 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
30+
; SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <2 x float> [[TMP2]], [[TMP4]]
31+
; SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <2 x i32> <i32 0, i32 1>
32+
; SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
33+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <2 x float> poison, float [[TMP6]], i64 0
34+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT9]], <2 x float> poison, <2 x i32> zeroinitializer
35+
; SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <2 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
36+
; SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <2 x float> [[TMP5]], [[TMP7]]
37+
; SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <2 x float> [[TMP8]], <2 x float> poison, <3 x i32> <i32 0, i32 1, i32 poison>
38+
; SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 2>
39+
; SPLIT_REMAINDER-NEXT: [[BLOCK11:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <1 x i32> <i32 2>
40+
; SPLIT_REMAINDER-NEXT: [[TMP11:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
41+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT12:%.*]] = insertelement <1 x float> poison, float [[TMP11]], i64 0
42+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT13:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT12]], <1 x float> poison, <1 x i32> zeroinitializer
43+
; SPLIT_REMAINDER-NEXT: [[TMP12:%.*]] = fmul <1 x float> [[SPLAT_SPLAT13]], [[BLOCK11]]
44+
; SPLIT_REMAINDER-NEXT: [[BLOCK14:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <1 x i32> <i32 2>
45+
; SPLIT_REMAINDER-NEXT: [[TMP13:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
46+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT15:%.*]] = insertelement <1 x float> poison, float [[TMP13]], i64 0
47+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT16:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT15]], <1 x float> poison, <1 x i32> zeroinitializer
48+
; SPLIT_REMAINDER-NEXT: [[TMP14:%.*]] = fmul <1 x float> [[SPLAT_SPLAT16]], [[BLOCK14]]
49+
; SPLIT_REMAINDER-NEXT: [[TMP15:%.*]] = fadd <1 x float> [[TMP12]], [[TMP14]]
50+
; SPLIT_REMAINDER-NEXT: [[BLOCK17:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <1 x i32> <i32 2>
51+
; SPLIT_REMAINDER-NEXT: [[TMP16:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
52+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT18:%.*]] = insertelement <1 x float> poison, float [[TMP16]], i64 0
53+
; SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT19:%.*]] = shufflevector <1 x float> [[SPLAT_SPLATINSERT18]], <1 x float> poison, <1 x i32> zeroinitializer
54+
; SPLIT_REMAINDER-NEXT: [[TMP17:%.*]] = fmul <1 x float> [[SPLAT_SPLAT19]], [[BLOCK17]]
55+
; SPLIT_REMAINDER-NEXT: [[TMP18:%.*]] = fadd <1 x float> [[TMP15]], [[TMP17]]
56+
; SPLIT_REMAINDER-NEXT: [[TMP19:%.*]] = shufflevector <1 x float> [[TMP18]], <1 x float> poison, <3 x i32> <i32 0, i32 poison, i32 poison>
57+
; SPLIT_REMAINDER-NEXT: [[TMP20:%.*]] = shufflevector <3 x float> [[TMP10]], <3 x float> [[TMP19]], <3 x i32> <i32 0, i32 1, i32 3>
58+
; SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP20]], ptr [[C]], align 4
59+
; SPLIT_REMAINDER-NEXT: ret void
60+
;
61+
; NO_SPLIT_REMAINDER-LABEL: define void @matmul(
62+
; NO_SPLIT_REMAINDER-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]]) {
63+
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD:%.*]] = load <3 x float>, ptr [[A]], align 4
64+
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[B]], align 4
65+
; NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[B]], i64 3
66+
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD2:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
67+
; NO_SPLIT_REMAINDER-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[B]], i64 6
68+
; NO_SPLIT_REMAINDER-NEXT: [[COL_LOAD4:%.*]] = load <3 x float>, ptr [[VEC_GEP3]], align 4
69+
; NO_SPLIT_REMAINDER-NEXT: [[BLOCK:%.*]] = shufflevector <3 x float> [[COL_LOAD1]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
70+
; NO_SPLIT_REMAINDER-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 0
71+
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <3 x float> poison, float [[TMP1]], i64 0
72+
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
73+
; NO_SPLIT_REMAINDER-NEXT: [[TMP2:%.*]] = fmul <3 x float> [[SPLAT_SPLAT]], [[BLOCK]]
74+
; NO_SPLIT_REMAINDER-NEXT: [[BLOCK5:%.*]] = shufflevector <3 x float> [[COL_LOAD2]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
75+
; NO_SPLIT_REMAINDER-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 1
76+
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT6:%.*]] = insertelement <3 x float> poison, float [[TMP3]], i64 0
77+
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT7:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT6]], <3 x float> poison, <3 x i32> zeroinitializer
78+
; NO_SPLIT_REMAINDER-NEXT: [[TMP4:%.*]] = fmul <3 x float> [[SPLAT_SPLAT7]], [[BLOCK5]]
79+
; NO_SPLIT_REMAINDER-NEXT: [[TMP5:%.*]] = fadd <3 x float> [[TMP2]], [[TMP4]]
80+
; NO_SPLIT_REMAINDER-NEXT: [[BLOCK8:%.*]] = shufflevector <3 x float> [[COL_LOAD4]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
81+
; NO_SPLIT_REMAINDER-NEXT: [[TMP6:%.*]] = extractelement <3 x float> [[COL_LOAD]], i64 2
82+
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLATINSERT9:%.*]] = insertelement <3 x float> poison, float [[TMP6]], i64 0
83+
; NO_SPLIT_REMAINDER-NEXT: [[SPLAT_SPLAT10:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT9]], <3 x float> poison, <3 x i32> zeroinitializer
84+
; NO_SPLIT_REMAINDER-NEXT: [[TMP7:%.*]] = fmul <3 x float> [[SPLAT_SPLAT10]], [[BLOCK8]]
85+
; NO_SPLIT_REMAINDER-NEXT: [[TMP8:%.*]] = fadd <3 x float> [[TMP5]], [[TMP7]]
86+
; NO_SPLIT_REMAINDER-NEXT: [[TMP9:%.*]] = shufflevector <3 x float> [[TMP8]], <3 x float> poison, <3 x i32> <i32 0, i32 1, i32 2>
87+
; NO_SPLIT_REMAINDER-NEXT: [[TMP10:%.*]] = shufflevector <3 x float> poison, <3 x float> [[TMP9]], <3 x i32> <i32 3, i32 4, i32 5>
88+
; NO_SPLIT_REMAINDER-NEXT: store <3 x float> [[TMP10]], ptr [[C]], align 4
89+
; NO_SPLIT_REMAINDER-NEXT: ret void
90+
;
91+
%a_load = load <3 x float>, ptr %a, align 4
92+
%b_load = load <9 x float>, ptr %b, align 4
93+
%matmul = tail call <3 x float> @llvm.matrix.multiply.v3f32.v9f32.v3f32(<3 x float> %a_load, <9 x float> %b_load, i32 1, i32 3, i32 3)
94+
store <3 x float> %matmul, ptr %c, align 4
95+
ret void
96+
}

0 commit comments

Comments
 (0)