Skip to content

Commit aa39785

Browse files
authored
[mlir][vector][memref] Add alignment attribute to memory access ops (llvm#144344)
Alignment information is important to allow LLVM backends such as AMDGPU to select wide memory accesses (e.g., dwordx4 or b128). Since this info is not always inferable, it's better to inform LLVM backends explicitly about it. Furthermore, alignment is not necessarily a property of the element type, but of each individual memory access op (we can have overaligned and underaligned accesses compared to the natural/preferred alignment of the element type). This patch introduces `alignment` attribute to memref/vector.load/store ops. Follow-up PRs will 1. Propagate the attribute to LLVM/SPIR-V. 2. Introduce `alignment` attribute to other vector memory access ops: vector.gather + vector.scatter vector.transfer_read + vector.transfer_write vector.compressstore + vector.expandload vector.maskedload + vector.maskedstore 3. Replace `--convert-vector-to-llvm='use-vector-alignment=1` with a simple pass to populate alignment attributes based on the vector types.
1 parent 163da87 commit aa39785

File tree

8 files changed

+172
-6
lines changed

8 files changed

+172
-6
lines changed

mlir/docs/DefiningDialects/Operations.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ Right now, the following primitive constraints are supported:
306306
* `IntPositive`: Specifying an integer attribute whose value is positive
307307
* `IntNonNegative`: Specifying an integer attribute whose value is
308308
non-negative
309+
* `IntPowerOf2`: Specifying an integer attribute whose value is a power of
310+
two > 0
309311
* `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
310312
elements
311313
* `ArrayMaxCount<N>`: Specifying an array attribute to have at most `N`

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,11 @@ def LoadOp : MemRef_Op<"load",
12161216
be reused in the cache. For details, refer to the
12171217
[https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).
12181218

1219+
An optional `alignment` attribute allows to specify the byte alignment of the
1220+
load operation. It must be a positive power of 2. The operation must access
1221+
memory at an address aligned to this boundary. Violations may lead to
1222+
architecture-specific faults or performance penalties.
1223+
A value of 0 indicates no specific alignment requirement.
12191224
Example:
12201225

12211226
```mlir
@@ -1226,7 +1231,39 @@ def LoadOp : MemRef_Op<"load",
12261231
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
12271232
[MemRead]>:$memref,
12281233
Variadic<Index>:$indices,
1229-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1234+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1235+
ConfinedAttr<OptionalAttr<I64Attr>,
1236+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
1237+
1238+
let builders = [
1239+
OpBuilder<(ins "Value":$memref,
1240+
"ValueRange":$indices,
1241+
CArg<"bool", "false">:$nontemporal,
1242+
CArg<"uint64_t", "0">:$alignment), [{
1243+
return build($_builder, $_state, memref, indices, nontemporal,
1244+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1245+
nullptr);
1246+
}]>,
1247+
OpBuilder<(ins "Type":$resultType,
1248+
"Value":$memref,
1249+
"ValueRange":$indices,
1250+
CArg<"bool", "false">:$nontemporal,
1251+
CArg<"uint64_t", "0">:$alignment), [{
1252+
return build($_builder, $_state, resultType, memref, indices, nontemporal,
1253+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1254+
nullptr);
1255+
}]>,
1256+
OpBuilder<(ins "TypeRange":$resultTypes,
1257+
"Value":$memref,
1258+
"ValueRange":$indices,
1259+
CArg<"bool", "false">:$nontemporal,
1260+
CArg<"uint64_t", "0">:$alignment), [{
1261+
return build($_builder, $_state, resultTypes, memref, indices, nontemporal,
1262+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1263+
nullptr);
1264+
}]>
1265+
];
1266+
12301267
let results = (outs AnyType:$result);
12311268

12321269
let extraClassDeclaration = [{
@@ -1912,6 +1949,11 @@ def MemRef_StoreOp : MemRef_Op<"store",
19121949
be reused in the cache. For details, refer to the
19131950
[https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction).
19141951

1952+
An optional `alignment` attribute allows to specify the byte alignment of the
1953+
store operation. It must be a positive power of 2. The operation must access
1954+
memory at an address aligned to this boundary. Violations may lead to
1955+
architecture-specific faults or performance penalties.
1956+
A value of 0 indicates no specific alignment requirement.
19151957
Example:
19161958

19171959
```mlir
@@ -1923,13 +1965,25 @@ def MemRef_StoreOp : MemRef_Op<"store",
19231965
Arg<AnyMemRef, "the reference to store to",
19241966
[MemWrite]>:$memref,
19251967
Variadic<Index>:$indices,
1926-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1968+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1969+
ConfinedAttr<OptionalAttr<I64Attr>,
1970+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
19271971

19281972
let builders = [
1973+
OpBuilder<(ins "Value":$valueToStore,
1974+
"Value":$memref,
1975+
"ValueRange":$indices,
1976+
CArg<"bool", "false">:$nontemporal,
1977+
CArg<"uint64_t", "0">:$alignment), [{
1978+
return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
1979+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1980+
nullptr);
1981+
}]>,
19291982
OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
19301983
$_state.addOperands(valueToStore);
19311984
$_state.addOperands(memref);
1932-
}]>];
1985+
}]>
1986+
];
19331987

19341988
let extraClassDeclaration = [{
19351989
Value getValueToStore() { return getOperand(0); }

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,12 +1809,42 @@ def Vector_LoadOp : Vector_Op<"load", [
18091809
```mlir
18101810
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
18111811
```
1812+
1813+
An optional `alignment` attribute allows to specify the byte alignment of the
1814+
load operation. It must be a positive power of 2. The operation must access
1815+
memory at an address aligned to this boundary. Violations may lead to
1816+
architecture-specific faults or performance penalties.
1817+
A value of 0 indicates no specific alignment requirement.
18121818
}];
18131819

18141820
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
18151821
[MemRead]>:$base,
18161822
Variadic<Index>:$indices,
1817-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1823+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1824+
ConfinedAttr<OptionalAttr<I64Attr>,
1825+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
1826+
1827+
let builders = [
1828+
OpBuilder<(ins "VectorType":$resultType,
1829+
"Value":$base,
1830+
"ValueRange":$indices,
1831+
CArg<"bool", "false">:$nontemporal,
1832+
CArg<"uint64_t", "0">:$alignment), [{
1833+
return build($_builder, $_state, resultType, base, indices, nontemporal,
1834+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1835+
nullptr);
1836+
}]>,
1837+
OpBuilder<(ins "TypeRange":$resultTypes,
1838+
"Value":$base,
1839+
"ValueRange":$indices,
1840+
CArg<"bool", "false">:$nontemporal,
1841+
CArg<"uint64_t", "0">:$alignment), [{
1842+
return build($_builder, $_state, resultTypes, base, indices, nontemporal,
1843+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1844+
nullptr);
1845+
}]>
1846+
];
1847+
18181848
let results = (outs AnyVectorOfAnyRank:$result);
18191849

18201850
let extraClassDeclaration = [{
@@ -1895,15 +1925,34 @@ def Vector_StoreOp : Vector_Op<"store", [
18951925
```mlir
18961926
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
18971927
```
1928+
1929+
An optional `alignment` attribute allows to specify the byte alignment of the
1930+
store operation. It must be a positive power of 2. The operation must access
1931+
memory at an address aligned to this boundary. Violations may lead to
1932+
architecture-specific faults or performance penalties.
1933+
A value of 0 indicates no specific alignment requirement.
18981934
}];
18991935

19001936
let arguments = (ins
19011937
AnyVectorOfAnyRank:$valueToStore,
19021938
Arg<AnyMemRef, "the reference to store to",
19031939
[MemWrite]>:$base,
19041940
Variadic<Index>:$indices,
1905-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
1906-
);
1941+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1942+
ConfinedAttr<OptionalAttr<I64Attr>,
1943+
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
1944+
1945+
let builders = [
1946+
OpBuilder<(ins "Value":$valueToStore,
1947+
"Value":$base,
1948+
"ValueRange":$indices,
1949+
CArg<"bool", "false">:$nontemporal,
1950+
CArg<"uint64_t", "0">:$alignment), [{
1951+
return build($_builder, $_state, valueToStore, base, indices, nontemporal,
1952+
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
1953+
nullptr);
1954+
}]>
1955+
];
19071956

19081957
let extraClassDeclaration = [{
19091958
MemRefType getMemRefType() {

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,10 @@ def IntPositive : AttrConstraint<
796796
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">,
797797
"whose value is positive">;
798798

799+
def IntPowerOf2 : AttrConstraint<
800+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
801+
"whose value is a power of two > 0">;
802+
799803
class ArrayMaxCount<int n> : AttrConstraint<
800804
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
801805
"with at most " # n # " elements">;

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,24 @@ func.func @test_store_zero_results2(%x: i32, %p: memref<i32>) {
962962

963963
// -----
964964

965+
func.func @invalid_load_alignment(%memref: memref<4xi32>) {
966+
%c0 = arith.constant 0 : index
967+
// expected-error @below {{'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
968+
%val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
969+
return
970+
}
971+
972+
// -----
973+
974+
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: i32) {
975+
%c0 = arith.constant 0 : index
976+
// expected-error @below {{'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
977+
memref.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>
978+
return
979+
}
980+
981+
// -----
982+
965983
func.func @test_alloc_memref_map_rank_mismatch() {
966984
^bb0:
967985
// expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}}

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,17 @@ func.func @zero_dim_no_idx(%arg0 : memref<i32>, %arg1 : memref<i32>, %arg2 : mem
265265
// CHECK: memref.store %{{.*}}, %{{.*}}[] : memref<i32>
266266
}
267267

268+
269+
// CHECK-LABEL: func @load_store_alignment
270+
func.func @load_store_alignment(%memref: memref<4xi32>) {
271+
%c0 = arith.constant 0 : index
272+
// CHECK: memref.load {{.*}} {alignment = 16 : i64}
273+
%val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32>
274+
// CHECK: memref.store {{.*}} {alignment = 16 : i64}
275+
memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>
276+
return
277+
}
278+
268279
// CHECK-LABEL: func @memref_view(%arg0
269280
func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
270281
%0 = memref.alloc() : memref<2048xi8>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1995,6 +1995,15 @@ func.func @vector_load(%src : memref<?xi8>) {
19951995

19961996
// -----
19971997

1998+
func.func @invalid_load_alignment(%memref: memref<4xi32>) {
1999+
%c0 = arith.constant 0 : index
2000+
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
2001+
%val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
2002+
return
2003+
}
2004+
2005+
// -----
2006+
19982007
//===----------------------------------------------------------------------===//
19992008
// vector.store
20002009
//===----------------------------------------------------------------------===//
@@ -2005,3 +2014,12 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
20052014
vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
20062015
return
20072016
}
2017+
2018+
// -----
2019+
2020+
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
2021+
%c0 = arith.constant 0 : index
2022+
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
2023+
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
2024+
return
2025+
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,16 @@ func.func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvecto
853853
return
854854
}
855855

856+
// CHECK-LABEL: func @load_store_alignment
857+
func.func @load_store_alignment(%memref: memref<4xi32>) {
858+
%c0 = arith.constant 0 : index
859+
// CHECK: vector.load {{.*}} {alignment = 16 : i64}
860+
%val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
861+
// CHECK: vector.store {{.*}} {alignment = 16 : i64}
862+
vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
863+
return
864+
}
865+
856866
// CHECK-LABEL: @masked_load_and_store
857867
func.func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
858868
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)