Skip to content

Commit 6048be9

Browse files
jiang1997amd-eochoalo
authored andcommitted
[mlir] Introduce AlignmentAttrOpInterface to expose MaybeAlign (llvm#161440)
Introduce a common interface for operations with alignment attributes across MemRef, Vector, and SPIRV dialects. The interface exposes getMaybeAlign() to retrieve alignment as llvm::MaybeAlign. This is the second part of the PRs addressing issue llvm#155677. Co-authored-by: Erick Ochoa Lopez <[email protected]>
1 parent 52a672f commit 6048be9

File tree

13 files changed

+168
-20
lines changed

13 files changed

+168
-20
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1515
#include "mlir/IR/Dialect.h"
16+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
1617
#include "mlir/Interfaces/CallInterfaces.h"
1718
#include "mlir/Interfaces/CastInterfaces.h"
1819
#include "mlir/Interfaces/ControlFlowInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/Arith/IR/ArithBase.td"
1313
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
14+
include "mlir/Interfaces/AlignmentAttrInterface.td"
1415
include "mlir/Interfaces/CastInterfaces.td"
1516
include "mlir/Interfaces/ControlFlowInterfaces.td"
1617
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -65,15 +66,15 @@ class AllocLikeOp<string mnemonic,
6566
list<Trait> traits = []> :
6667
MemRef_Op<mnemonic,
6768
!listconcat([
68-
AttrSizedOperandSegments
69+
AttrSizedOperandSegments,
70+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
6971
], traits)> {
7072

7173
let arguments = (ins Variadic<Index>:$dynamicSizes,
7274
// The symbolic operands (the ones in square brackets)
7375
// bind to the symbols of the memref's layout map.
7476
Variadic<Index>:$symbolOperands,
75-
ConfinedAttr<OptionalAttr<I64Attr>,
76-
[IntMinValue<0>]>:$alignment);
77+
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
7778
let results = (outs Res<AnyMemRef, "",
7879
[MemAlloc<resource, 0, FullEffect>]>:$memref);
7980

@@ -269,7 +270,8 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, [
269270
//===----------------------------------------------------------------------===//
270271

271272

272-
def MemRef_ReallocOp : MemRef_Op<"realloc"> {
273+
def MemRef_ReallocOp : MemRef_Op<"realloc",
274+
[DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
273275
let summary = "memory reallocation operation";
274276
let description = [{
275277
The `realloc` operation changes the size of a memory region. The memory
@@ -335,8 +337,7 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
335337
let arguments = (ins Arg<MemRefRankOf<[AnyType], [1]>, "",
336338
[MemFreeAt<0, FullEffect>]>:$source,
337339
Optional<Index>:$dynamicResultSize,
338-
ConfinedAttr<OptionalAttr<I64Attr>,
339-
[IntMinValue<0>]>:$alignment);
340+
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
340341

341342
let results = (outs Res<MemRefRankOf<[AnyType], [1]>, "",
342343
[MemAlloc<DefaultResource, 1,
@@ -1160,7 +1161,8 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
11601161
// GlobalOp
11611162
//===----------------------------------------------------------------------===//
11621163

1163-
def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
1164+
def MemRef_GlobalOp : MemRef_Op<"global", [Symbol,
1165+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
11641166
let summary = "declare or define a global memref variable";
11651167
let description = [{
11661168
The `memref.global` operation declares or defines a named global memref
@@ -1235,6 +1237,7 @@ def LoadOp : MemRef_Op<"load",
12351237
"memref", "result",
12361238
"::llvm::cast<MemRefType>($_self).getElementType()">,
12371239
MemRefsNormalizable,
1240+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
12381241
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
12391242
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
12401243
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
@@ -2010,6 +2013,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
20102013
"memref", "value",
20112014
"::llvm::cast<MemRefType>($_self).getElementType()">,
20122015
MemRefsNormalizable,
2016+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
20132017
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
20142018
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
20152019
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#ifndef MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
1717
#define MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
1818

19+
include "mlir/Interfaces/AlignmentAttrInterface.td"
20+
1921
//===----------------------------------------------------------------------===//
2022
// SPV_KHR_cooperative_matrix extension ops.
2123
//===----------------------------------------------------------------------===//
@@ -62,7 +64,7 @@ def SPIRV_KHRCooperativeMatrixLengthOp :
6264

6365
// -----
6466

65-
def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", []> {
67+
def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
6668
let summary = "Loads a cooperative matrix through a pointer";
6769

6870
let description = [{
@@ -148,7 +150,7 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
148150

149151
// -----
150152

151-
def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", []> {
153+
def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
152154
let summary = "Stores a cooperative matrix through a pointer";
153155

154156
let description = [{

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#define MLIR_DIALECT_SPIRV_IR_MEMORY_OPS
1616

1717
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
18+
include "mlir/Interfaces/AlignmentAttrInterface.td"
19+
1820

1921
// -----
2022

@@ -79,7 +81,7 @@ def SPIRV_AccessChainOp : SPIRV_Op<"AccessChain", [Pure]> {
7981

8082
// -----
8183

82-
def SPIRV_CopyMemoryOp : SPIRV_Op<"CopyMemory", []> {
84+
def SPIRV_CopyMemoryOp : SPIRV_Op<"CopyMemory", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
8385
let summary = [{
8486
Copy from the memory pointed to by Source to the memory pointed to by
8587
Target. Both operands must be non-void pointers and having the same <id>
@@ -182,7 +184,7 @@ def SPIRV_InBoundsPtrAccessChainOp : SPIRV_Op<"InBoundsPtrAccessChain", [Pure]>
182184

183185
// -----
184186

185-
def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
187+
def SPIRV_LoadOp : SPIRV_Op<"Load", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
186188
let summary = "Load through a pointer.";
187189

188190
let description = [{
@@ -310,7 +312,7 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
310312

311313
// -----
312314

313-
def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
315+
def SPIRV_StoreOp : SPIRV_Op<"Store", [DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]> {
314316
let summary = "Store through a pointer.";
315317

316318
let description = [{

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/SPIRV/Interfaces/SPIRVImageInterfaces.h"
2121
#include "mlir/IR/BuiltinOps.h"
2222
#include "mlir/IR/OpImplementation.h"
23+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
2324
#include "mlir/Interfaces/CallInterfaces.h"
2425
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2526
#include "mlir/Interfaces/FunctionInterfaces.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/Dialect.h"
2424
#include "mlir/IR/OpDefinition.h"
2525
#include "mlir/IR/PatternMatch.h"
26+
#include "mlir/Interfaces/AlignmentAttrInterface.h"
2627
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2728
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2829
#include "mlir/Interfaces/IndexingMapOpInterface.h"

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
1919
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
2020
include "mlir/Dialect/Vector/IR/Vector.td"
2121
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
22+
include "mlir/Interfaces/AlignmentAttrInterface.td"
2223
include "mlir/Interfaces/ControlFlowInterfaces.td"
2324
include "mlir/Interfaces/DestinationStyleOpInterface.td"
2425
include "mlir/Interfaces/IndexingMapOpInterface.td"
@@ -1653,7 +1654,8 @@ def Vector_TransferWriteOp :
16531654

16541655
def Vector_LoadOp : Vector_Op<"load", [
16551656
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1656-
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1657+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1658+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
16571659
]> {
16581660
let summary = "reads an n-D slice of memory into an n-D vector";
16591661
let description = [{
@@ -1770,7 +1772,8 @@ def Vector_LoadOp : Vector_Op<"load", [
17701772

17711773
def Vector_StoreOp : Vector_Op<"store", [
17721774
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1773-
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
1775+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1776+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
17741777
]> {
17751778
let summary = "writes an n-D vector to an n-D slice of memory";
17761779
let description = [{
@@ -1875,7 +1878,10 @@ def Vector_StoreOp : Vector_Op<"store", [
18751878
}
18761879

18771880
def Vector_MaskedLoadOp :
1878-
Vector_Op<"maskedload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
1881+
Vector_Op<"maskedload", [
1882+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1883+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
1884+
]>,
18791885
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18801886
Variadic<Index>:$indices,
18811887
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -1967,7 +1973,10 @@ def Vector_MaskedLoadOp :
19671973
}
19681974

19691975
def Vector_MaskedStoreOp :
1970-
Vector_Op<"maskedstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
1976+
Vector_Op<"maskedstore", [
1977+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
1978+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
1979+
]>,
19711980
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
19721981
Variadic<Index>:$indices,
19731982
VectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2048,7 +2057,8 @@ def Vector_GatherOp :
20482057
Vector_Op<"gather", [
20492058
DeclareOpInterfaceMethods<MaskableOpInterface>,
20502059
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2051-
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
2060+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
2061+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
20522062
]>,
20532063
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
20542064
Variadic<Index>:$offsets,
@@ -2151,7 +2161,10 @@ def Vector_GatherOp :
21512161
}
21522162

21532163
def Vector_ScatterOp :
2154-
Vector_Op<"scatter", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2164+
Vector_Op<"scatter", [
2165+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2166+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2167+
]>,
21552168
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21562169
Variadic<Index>:$offsets,
21572170
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
@@ -2236,7 +2249,10 @@ def Vector_ScatterOp :
22362249
}
22372250

22382251
def Vector_ExpandLoadOp :
2239-
Vector_Op<"expandload", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2252+
Vector_Op<"expandload", [
2253+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2254+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2255+
]>,
22402256
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
22412257
Variadic<Index>:$indices,
22422258
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
@@ -2324,7 +2340,10 @@ def Vector_ExpandLoadOp :
23242340
}
23252341

23262342
def Vector_CompressStoreOp :
2327-
Vector_Op<"compressstore", [DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>]>,
2343+
Vector_Op<"compressstore", [
2344+
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2345+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2346+
]>,
23282347
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
23292348
Variadic<Index>:$indices,
23302349
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- AlignmentAttrInterface.h - Alignment attribute interface -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
10+
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
11+
12+
#include "mlir/IR/OpDefinition.h"
13+
#include "llvm/Support/Alignment.h"
14+
15+
namespace mlir {
16+
class MLIRContext;
17+
} // namespace mlir
18+
19+
#include "mlir/Interfaces/AlignmentAttrInterface.h.inc"
20+
21+
#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_H
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===- AlignmentAttrInterface.td - Alignment attribute interface -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines an interface for operations that expose an optional
10+
// alignment attribute.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
15+
#define MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD
16+
17+
include "mlir/IR/OpBase.td"
18+
19+
def AlignmentAttrOpInterface : OpInterface<"AlignmentAttrOpInterface"> {
20+
let description = [{
21+
An interface for operations that carry an optional alignment attribute and
22+
want to expose it as an `llvm::MaybeAlign` helper.
23+
}];
24+
25+
let cppNamespace = "::mlir";
26+
27+
let methods = [
28+
InterfaceMethod<[{
29+
Returns the alignment encoded on the operation as an `llvm::MaybeAlign`.
30+
Operations providing a differently named accessor can override the
31+
default implementation.
32+
}],
33+
"::llvm::MaybeAlign",
34+
"getMaybeAlign",
35+
(ins),
36+
[{
37+
// Defensive: trait implementations are expected to validate power-of-two
38+
// alignments, but we still guard against accidental misuse.
39+
auto alignmentOpt = $_op.getAlignment();
40+
if (!alignmentOpt || *alignmentOpt <= 0)
41+
return ::llvm::MaybeAlign();
42+
uint64_t value = static_cast<uint64_t>(*alignmentOpt);
43+
if (!::llvm::isPowerOf2_64(value))
44+
return ::llvm::MaybeAlign();
45+
return ::llvm::MaybeAlign(value);
46+
}]
47+
>
48+
];
49+
50+
let extraTraitClassDeclaration = [{
51+
::llvm::MaybeAlign getMaybeAlign() {
52+
// Defensive: trait implementations are expected to validate power-of-two
53+
// alignments, but we still guard against accidental misuse.
54+
auto alignmentOpt = (*static_cast<ConcreteOp *>(this)).getAlignment();
55+
if (!alignmentOpt || *alignmentOpt <= 0)
56+
return ::llvm::MaybeAlign();
57+
uint64_t value = static_cast<uint64_t>(*alignmentOpt);
58+
if (!::llvm::isPowerOf2_64(value))
59+
return ::llvm::MaybeAlign();
60+
return ::llvm::MaybeAlign(value);
61+
}
62+
}];
63+
}
64+
65+
#endif // MLIR_INTERFACES_ALIGNMENTATTRINTERFACE_TD

mlir/include/mlir/Interfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_mlir_interface(AlignmentAttrInterface)
12
add_mlir_interface(CallInterfaces)
23
add_mlir_interface(CastInterfaces)
34
add_mlir_interface(ControlFlowInterfaces)

0 commit comments

Comments
 (0)