Skip to content

Commit 48aed55

Browse files
fywkevinYuanwei Fang
authored andcommitted
[Proton][Dialect] Middle-end Proton operator definitions (#5754)
This PR is a follow-up of #5677 for proton compiler mid-end support, focusing on op definition. Specifically, we 1. added the tablegen definitions of the proton mid-end operators, 2. removed attributes of the front-end `RecordOp` (make it a true marker), 3. cleaned up the dialect's macro. --------- Co-authored-by: Yuanwei Fang <[email protected]>
1 parent 81c251b commit 48aed55

File tree

3 files changed

+82
-21
lines changed

3 files changed

+82
-21
lines changed

third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_
2-
#define TRITON_DIALECT_PROTON_IR_DIALECT_H_
1+
#ifndef DIALECT_PROTON_IR_DIALECT_H_
2+
#define DIALECT_PROTON_IR_DIALECT_H_
33

44
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
55
#include "mlir/IR/BuiltinOps.h"
66
#include "mlir/IR/Dialect.h"
77
#include "mlir/IR/PatternMatch.h"
88
#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc"
99
#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc"
10+
#include "triton/Dialect/Triton/IR/Dialect.h"
11+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1012

1113
#define GET_ATTRDEF_CLASSES
1214
#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc"
@@ -20,4 +22,4 @@ namespace proton {} // namespace proton
2022
} // namespace triton
2123
} // namespace mlir
2224

23-
#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_
25+
#endif // DIALECT_PROTON_IR_DIALECT_H_

third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,21 @@
44
include "mlir/IR/OpBase.td"
55
include "mlir/IR/EnumAttr.td"
66
include "triton/Dialect/Triton/IR/TritonTypes.td"
7+
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
78
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
89
include "mlir/Interfaces/InferTypeOpInterface.td"
910
include "mlir/Interfaces/SideEffectInterfaces.td"
1011
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
1112
include "ProtonDialect.td"
1213
include "ProtonAttrDefs.td"
1314

14-
class TT_Proton_Op<string mnemonic, list<Trait> traits = []> :
15+
class PT_Op<string mnemonic, list<Trait> traits = []> :
1516
Op<Proton_Dialect, mnemonic, !listconcat(traits, [])> {
1617
}
1718

19+
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
20+
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
21+
1822
// Proton profiling metric.
1923
def MetricAttr : I32EnumAttr<
2024
"Metric", "",
@@ -34,32 +38,89 @@ def GranularityAttr : I32EnumAttr<
3438
let cppNamespace = "::mlir::triton::proton";
3539
}
3640

37-
def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
38-
let summary = "Record a GPU hardware event";
41+
def PT_RecordOp : PT_Op<"record", [
42+
MemoryEffects<[MemRead<DefaultResource>]>,
43+
MemoryEffects<[MemWrite<DefaultResource>]>
44+
]> {
45+
let summary = "Record a GPU metric event";
3946

4047
let description = [{
41-
The operator records GPU events from performance counters.
42-
Currently only cycle counter is supported.
48+
The operator records GPU event of a particular metric.
49+
Essentially a marker with a region id.
4350

4451
Example:
4552

4653
```mlir
4754
proton.record() {isStart = true, regionId = 4 : i32}
4855
...
4956
proton.record() {isStart = false, regionId = 4 : i32}
50-
...
51-
proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32}
52-
...
53-
proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32}
5457
```
5558
}];
5659
let arguments = (
5760
ins BoolAttr: $isStart,
61+
ConfinedAttr<I32Attr, [IntNonNegative]>:$regionId
62+
);
63+
64+
let assemblyFormat = " `(` operands `)` attr-dict";
65+
}
66+
67+
def PT_CircularRecordOp : PT_Op<"circular_record", [
68+
MemoryEffects<[MemRead<DefaultResource>]>,
69+
MemoryEffects<[MemWrite<DefaultResource>]>
70+
]> {
71+
let summary = "Record a GPU metric event into a circular buffer";
72+
73+
let description = [{
74+
Records a metric event into a circular buffer backed by the internal memory `data`.
75+
The circular buffer indexing `indexPtr` is automatically maintained. Older events
76+
get dropped if the `data` is full.
77+
}];
78+
let arguments = (
79+
ins TTG_MemDescType:$data,
80+
TT_PtrLike :$indexPtr,
81+
BoolAttr: $isStart,
5882
ConfinedAttr<I32Attr, [IntNonNegative]>:$regionId,
5983
DefaultValuedAttr<MetricAttr, "Metric::CYCLE">:$metric,
6084
DefaultValuedAttr<GranularityAttr, "Granularity::WARPGROUP">:$granularity
6185
);
62-
let assemblyFormat = " `(` operands `)` attr-dict";
86+
let hasVerifier = 1;
87+
88+
let assemblyFormat = [{$data `,` $indexPtr attr-dict `:` qualified(type($data)) `,` type($indexPtr)}];
89+
}
90+
91+
def PT_FinalizeOp : PT_Op<"finalize", [
92+
MemoryEffects<[MemRead<SharedMemory>]>,
93+
MemoryEffects<[MemRead<GlobalMemory>]>,
94+
MemoryEffects<[MemWrite<GlobalMemory>]>
95+
]> {
96+
let summary = "Finalize the intra kernel profiler";
97+
98+
let description = [{
99+
Finalize the intra kernel profiler, writing back the metadata and profile to the global memory.
100+
}];
101+
let arguments = (
102+
ins TTG_MemDescType:$data,
103+
TT_PtrLike :$indexPtr,
104+
TT_PtrLike :$ptr,
105+
I32Attr :$size
106+
);
107+
108+
let assemblyFormat = [{$data `,` $indexPtr `,` $ptr attr-dict `:` qualified(type($data)) `,` type($indexPtr) `,` type($ptr)}];
109+
}
110+
111+
def PT_InitOp : PT_Op<"init", [
112+
MemoryEffects<[MemAlloc<GlobalMemory>]>
113+
]> {
114+
let summary = "Initialize the intra kernel profiler";
115+
116+
let description = [{
117+
Stack allocation and initialization for the intra kernel profiler.
118+
`indexPtr` stores the number of entries proton recorded (zero initialized).
119+
We expect `indexPtr` to be register promoted during the LLVM lowering phase.
120+
}];
121+
let arguments = (ins);
122+
let results = (outs TT_PtrLike :$indexPtr);
123+
let assemblyFormat = "attr-dict `:` type($indexPtr)";
63124
}
64125

65126
#endif // PROTON_OPS

third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@ namespace mlir {
1818
namespace triton {
1919
namespace proton {
2020

21-
// -- RecordOp --
22-
void RecordOp::getEffects(
23-
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
24-
&effects) {
25-
effects.emplace_back(MemoryEffects::Write::get(),
26-
SideEffects::DefaultResource::get());
27-
effects.emplace_back(MemoryEffects::Read::get(),
28-
SideEffects::DefaultResource::get());
21+
// -- CircularRecordOp --
22+
LogicalResult CircularRecordOp::verify() {
23+
// TODO(fywkevin): checks the following:
24+
// 1. circular buffer size power of 2.
25+
// 2. function's noinline is false.
26+
return success();
2927
}
3028

3129
} // namespace proton

0 commit comments

Comments
 (0)