Skip to content

Commit 625aa09

Browse files
authored
[Matrix] Use data layout index type for lowering matrix intrinsics (llvm#162646)
To properly support the matrix intrinsics on, e.g., 32-bit platforms (without the need to emit `libc` calls), `LowerMatrixIntrinsics` pass should generate code that performs strided index calculations using the same pointer bit-width as the matrix pointers, as determined by the data layout. This patch updates the `LowerMatrixInstrics` transform to make this the case. PR: llvm#162646
1 parent dbab36a commit 625aa09

File tree

6 files changed

+901
-29
lines changed

6 files changed

+901
-29
lines changed

llvm/docs/LangRef.rst

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21074,12 +21074,12 @@ Overview:
2107421074

2107521075
The '``llvm.matrix.column.major.load.*``' intrinsics load a ``<Rows> x <Cols>``
2107621076
matrix using a stride of ``%Stride`` to compute the start address of the
21077-
different columns. The offset is computed using ``%Stride``'s bitwidth. This
21078-
allows for convenient loading of sub matrixes. If ``<IsVolatile>`` is true, the
21079-
intrinsic is considered a :ref:`volatile memory access <volatile>`. The result
21080-
matrix is returned in the result vector. If the ``%Ptr`` argument is known to
21081-
be aligned to some boundary, this can be specified as an attribute on the
21082-
argument.
21077+
different columns. This allows for convenient loading of sub matrixes.
21078+
Independent of ``%Stride``'s bitwidth, the offset is computed using the target
21079+
daya layout's pointer index type. If ``<IsVolatile>`` is true, the intrinsic is
21080+
considered a :ref:`volatile memory access <volatile>`. The result matrix is
21081+
returned in the result vector. If the ``%Ptr`` argument is known to be aligned
21082+
to some boundary, this can be specified as an attribute on the argument.
2108321083

2108421084
Arguments:
2108521085
""""""""""
@@ -21114,9 +21114,9 @@ Overview:
2111421114

2111521115
The '``llvm.matrix.column.major.store.*``' intrinsics store the ``<Rows> x
2111621116
<Cols>`` matrix in ``%In`` to memory using a stride of ``%Stride`` between
21117-
columns. The offset is computed using ``%Stride``'s bitwidth. If
21118-
``<IsVolatile>`` is true, the intrinsic is considered a
21119-
:ref:`volatile memory access <volatile>`.
21117+
columns. Independent of ``%Stride``'s bitwidth, the offset is computed using
21118+
the target daya layout's pointer index type. If ``<IsVolatile>`` is true, the
21119+
intrinsic is considered a :ref:`volatile memory access <volatile>`.
2112021120

2112121121
If the ``%Ptr`` argument is known to be aligned to some boundary, this can be
2112221122
specified as an attribute on the argument.

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,24 @@ class LowerMatrixIntrinsics {
12951295
return commonAlignment(InitialAlign, ElementSizeInBits / 8);
12961296
}
12971297

1298+
IntegerType *getIndexType(Value *Ptr) const {
1299+
return cast<IntegerType>(DL.getIndexType(Ptr->getType()));
1300+
}
1301+
1302+
Value *getIndex(Value *Ptr, uint64_t V) const {
1303+
return ConstantInt::get(getIndexType(Ptr), V);
1304+
}
1305+
1306+
Value *castToIndexType(Value *Ptr, Value *V, IRBuilder<> &Builder) const {
1307+
assert(isa<IntegerType>(V->getType()) &&
1308+
"Attempted to cast non-integral type to integer index");
1309+
// In case the data layout's index type differs in width from the type of
1310+
// the value we're given, truncate or zero extend to the appropriate width.
1311+
// We zero extend here as indices are unsigned.
1312+
return Builder.CreateZExtOrTrunc(V, getIndexType(Ptr),
1313+
V->getName() + ".cast");
1314+
}
1315+
12981316
/// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
12991317
/// vectors.
13001318
MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
@@ -1304,6 +1322,7 @@ class LowerMatrixIntrinsics {
13041322
Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
13051323
Value *EltPtr = Ptr;
13061324
MatrixTy Result;
1325+
Stride = castToIndexType(Ptr, Stride, Builder);
13071326
for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
13081327
Value *GEP = computeVectorAddr(
13091328
EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
@@ -1325,14 +1344,14 @@ class LowerMatrixIntrinsics {
13251344
ShapeInfo ResultShape, Type *EltTy,
13261345
IRBuilder<> &Builder) {
13271346
Value *Offset = Builder.CreateAdd(
1328-
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1347+
Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
13291348

13301349
Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
13311350
auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
13321351
ResultShape.NumColumns);
13331352

13341353
return loadMatrix(TileTy, TileStart, Align,
1335-
Builder.getInt64(MatrixShape.getStride()), IsVolatile,
1354+
getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
13361355
ResultShape, Builder);
13371356
}
13381357

@@ -1363,14 +1382,15 @@ class LowerMatrixIntrinsics {
13631382
MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
13641383
Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
13651384
Value *Offset = Builder.CreateAdd(
1366-
Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
1385+
Builder.CreateMul(J, getIndex(MatrixPtr, MatrixShape.getStride())), I);
13671386

13681387
Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
13691388
auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
13701389
StoreVal.getNumColumns());
13711390

13721391
storeMatrix(TileTy, StoreVal, TileStart, MAlign,
1373-
Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
1392+
getIndex(MatrixPtr, MatrixShape.getStride()), IsVolatile,
1393+
Builder);
13741394
}
13751395

13761396
/// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
@@ -1380,6 +1400,7 @@ class LowerMatrixIntrinsics {
13801400
IRBuilder<> &Builder) {
13811401
auto *VType = cast<FixedVectorType>(Ty);
13821402
Value *EltPtr = Ptr;
1403+
Stride = castToIndexType(Ptr, Stride, Builder);
13831404
for (auto Vec : enumerate(StoreVal.vectors())) {
13841405
Value *GEP = computeVectorAddr(
13851406
EltPtr,
@@ -2011,18 +2032,17 @@ class LowerMatrixIntrinsics {
20112032
const unsigned TileM = std::min(M - K, unsigned(TileSize));
20122033
MatrixTy A =
20132034
loadMatrix(APtr, LoadOp0->getAlign(), LoadOp0->isVolatile(),
2014-
LShape, Builder.getInt64(I), Builder.getInt64(K),
2035+
LShape, getIndex(APtr, I), getIndex(APtr, K),
20152036
{TileR, TileM}, EltType, Builder);
20162037
MatrixTy B =
20172038
loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(),
2018-
RShape, Builder.getInt64(K), Builder.getInt64(J),
2039+
RShape, getIndex(BPtr, K), getIndex(BPtr, J),
20192040
{TileM, TileC}, EltType, Builder);
20202041
emitMatrixMultiply(Res, A, B, Builder, true, false,
20212042
getFastMathFlags(MatMul));
20222043
}
20232044
storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
2024-
Builder.getInt64(I), Builder.getInt64(J), EltType,
2025-
Builder);
2045+
getIndex(CPtr, I), getIndex(CPtr, J), EltType, Builder);
20262046
}
20272047
}
20282048

@@ -2254,15 +2274,14 @@ class LowerMatrixIntrinsics {
22542274
/// Lower load instructions.
22552275
MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
22562276
IRBuilder<> &Builder) {
2257-
return LowerLoad(Inst, Ptr, Inst->getAlign(),
2258-
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
2259-
Builder);
2277+
return LowerLoad(Inst, Ptr, Inst->getAlign(), getIndex(Ptr, SI.getStride()),
2278+
Inst->isVolatile(), SI, Builder);
22602279
}
22612280

22622281
MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
22632282
Value *Ptr, IRBuilder<> &Builder) {
22642283
return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2265-
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI,
2284+
getIndex(Ptr, SI.getStride()), Inst->isVolatile(), SI,
22662285
Builder);
22672286
}
22682287

0 commit comments

Comments
 (0)