Skip to content

Commit b36b81f

Browse files
committed
Update
1 parent 9220772 commit b36b81f

27 files changed

+260
-136
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(Proton)
2+
add_subdirectory(ProtonGPU)
Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
22

33
set(LLVM_TARGET_DEFINITIONS ProtonOps.td)
4-
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton)
5-
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton)
6-
mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions)
74
mlir_tablegen(Ops.h.inc -gen-op-decls)
85
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
96
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
107
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
11-
add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc)
128
add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc)
13-
add_public_tablegen_target(ProtonTableGen)
9+
10+
set(LLVM_TARGET_DEFINITIONS ProtonDialect.td)
11+
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton)
12+
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton)
13+
add_mlir_doc(ProtonDialect ProtonDialect dialects/ -gen-dialect-doc)
1414

1515
set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td)
1616
mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls)
1717
mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs)
18-
add_public_tablegen_target(ProtonAttrDefsIncGen)
18+
19+
add_public_tablegen_target(ProtonTableGen)

third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,12 @@
66
#include "mlir/IR/Dialect.h"
77
#include "mlir/IR/PatternMatch.h"
88
#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc"
9-
#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc"
109
#include "triton/Dialect/Triton/IR/Dialect.h"
11-
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1210

1311
#define GET_ATTRDEF_CLASSES
1412
#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc"
1513

1614
#define GET_OP_CLASSES
1715
#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc"
1816

19-
namespace mlir {
20-
namespace triton {
21-
namespace proton {} // namespace proton
22-
} // namespace triton
23-
} // namespace mlir
24-
2517
#endif // DIALECT_PROTON_IR_DIALECT_H_
Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1-
#ifndef PROTON_ATTRDEFS
2-
#define PROTON_ATTRDEFS
1+
#ifndef PROTON_ATTR_DEFS
2+
#define PROTON_ATTR_DEFS
33

44
include "mlir/IR/AttrTypeBase.td"
55
include "ProtonDialect.td"
66

77
class Proton_Attr<string name, list<Trait> traits = [],
8-
string baseCppClass = "::mlir::Attribute">
8+
string baseCppClass = "::mlir::Attribute">
99
: AttrDef<Proton_Dialect, name, traits, baseCppClass> {
1010
}
1111

12-
#endif // PROTON_ATTRDEFS
12+
def MetricAttr : I32EnumAttr<
13+
"Metric", "",
14+
[
15+
I32EnumAttrCase<"CYCLE", 0, "cycle">,
16+
]> {
17+
let cppNamespace = "::mlir::triton::proton";
18+
}
19+
20+
#endif // PROTON_ATTR_DEFS

third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ def Proton_Dialect : Dialect {
1313
}];
1414

1515
let dependentDialects = [];
16+
17+
let useDefaultTypePrinterParser = 1;
18+
let useDefaultAttributePrinterParser = 1;
19+
let usePropertiesForAttributes = 1;
1620
}
1721

1822
#endif

third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,16 @@
33

44
include "mlir/IR/OpBase.td"
55
include "mlir/IR/EnumAttr.td"
6-
include "triton/Dialect/Triton/IR/TritonTypes.td"
7-
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
86
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
97
include "mlir/Interfaces/InferTypeOpInterface.td"
108
include "mlir/Interfaces/SideEffectInterfaces.td"
11-
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
129
include "ProtonDialect.td"
1310
include "ProtonAttrDefs.td"
1411

1512
class PT_Op<string mnemonic, list<Trait> traits = []> :
1613
Op<Proton_Dialect, mnemonic, !listconcat(traits, [])> {
1714
}
1815

19-
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
20-
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
21-
22-
// Proton profiling metric.
23-
def MetricAttr : I32EnumAttr<
24-
"Metric", "",
25-
[
26-
I32EnumAttrCase<"CYCLE", 0, "cycle">,
27-
]> {
28-
let cppNamespace = "::mlir::triton::proton";
29-
}
30-
31-
// Proton profiling granularity.
32-
def GranularityAttr : I32EnumAttr<
33-
"Granularity", "",
34-
[
35-
I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">,
36-
I32EnumAttrCase<"WARP", 1, "warp">,
37-
]> {
38-
let cppNamespace = "::mlir::triton::proton";
39-
}
40-
4116
def PT_InitScopeOp : PT_Op<"init_scope", [Pure]> {
4217
let summary = "Initialize a scope";
4318

@@ -84,65 +59,4 @@ def PT_RecordOp : PT_Op<"record", [
8459
let assemblyFormat = "(`start` $isStart^):(`end`)? $scopeId attr-dict";
8560
}
8661

87-
def PT_CircularRecordOp : PT_Op<"circular_record", [
88-
MemoryEffects<[MemRead<DefaultResource>, MemWrite<DefaultResource>]>
89-
]> {
90-
let summary = "Record a GPU metric event into a circular buffer";
91-
92-
let description = [{
93-
Records a metric event into a circular buffer backed by the internal memory `data`.
94-
The circular buffer indexing `indexPtr` is automatically maintained. Older events
95-
get dropped if the `data` is full.
96-
}];
97-
let arguments = (
98-
ins UnitAttr: $isStart,
99-
I32: $scopeId,
100-
TTG_MemDescType:$data,
101-
TT_PtrLike :$indexPtr,
102-
DefaultValuedAttr<MetricAttr, "Metric::CYCLE">:$metric,
103-
DefaultValuedAttr<GranularityAttr, "Granularity::WARPGROUP">:$granularity
104-
);
105-
let hasVerifier = 1;
106-
107-
let assemblyFormat = [{
108-
(`start` $isStart^):(`end`)? $scopeId `,` $data `,` $indexPtr attr-dict `:`
109-
qualified(type($data)) `,` type($indexPtr)
110-
}];
111-
}
112-
113-
def PT_FinalizeOp : PT_Op<"finalize", [
114-
MemoryEffects<[MemRead<SharedMemory>]>,
115-
MemoryEffects<[MemRead<GlobalMemory>]>,
116-
MemoryEffects<[MemWrite<GlobalMemory>]>
117-
]> {
118-
let summary = "Finalize the intra kernel profiler";
119-
120-
let description = [{
121-
Finalize the intra kernel profiler, writing back the metadata and profile to the global memory.
122-
}];
123-
let arguments = (
124-
ins TTG_MemDescType:$data,
125-
TT_PtrLike :$indexPtr,
126-
TT_PtrLike :$ptr,
127-
I32Attr :$size
128-
);
129-
130-
let assemblyFormat = [{$data `,` $indexPtr `,` $ptr attr-dict `:` qualified(type($data)) `,` type($indexPtr) `,` type($ptr)}];
131-
}
132-
133-
def PT_InitOp : PT_Op<"init", [
134-
MemoryEffects<[MemAlloc<GlobalMemory>]>
135-
]> {
136-
let summary = "Initialize the intra kernel profiler";
137-
138-
let description = [{
139-
Stack allocation and initialization for the intra kernel profiler.
140-
`indexPtr` stores the number of entries proton recorded (zero initialized).
141-
We expect `indexPtr` to be register promoted during the LLVM lowering phase.
142-
}];
143-
let arguments = (ins);
144-
let results = (outs TT_PtrLike :$indexPtr);
145-
let assemblyFormat = "attr-dict `:` type($indexPtr)";
146-
}
147-
14862
#endif // PROTON_OPS
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(IR)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
2+
3+
set(LLVM_TARGET_DEFINITIONS ProtonGPUOps.td)
4+
mlir_tablegen(Ops.h.inc -gen-op-decls)
5+
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
6+
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
7+
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
8+
add_mlir_doc(ProtonGPUOps ProtonGPUOps dialects/ -gen-op-doc)
9+
10+
set(LLVM_TARGET_DEFINITIONS ProtonGPUDialect.td)
11+
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton_gpu)
12+
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton_gpu)
13+
add_mlir_doc(ProtonGPUDialect ProtonGPUDialect dialects/ -gen-dialect-doc)
14+
15+
set(LLVM_TARGET_DEFINITIONS ProtonGPUAttrDefs.td)
16+
mlir_tablegen(ProtonGPUAttrDefs.h.inc -gen-attrdef-decls)
17+
mlir_tablegen(ProtonGPUAttrDefs.cpp.inc -gen-attrdef-defs)
18+
19+
add_public_tablegen_target(ProtonGPUTableGen)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef DIALECT_PROTON_GPU_IR_DIALECT_H_
2+
#define DIALECT_PROTON_GPU_IR_DIALECT_H_
3+
4+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
5+
#include "mlir/IR/BuiltinOps.h"
6+
#include "mlir/IR/Dialect.h"
7+
#include "mlir/IR/PatternMatch.h"
8+
#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc"
9+
#include "proton/dialect/include/Dialect/ProtonGPU/IR/Dialect.h.inc"
10+
#include "triton/Dialect/Triton/IR/Dialect.h"
11+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
12+
13+
#define GET_ATTRDEF_CLASSES
14+
#include "proton/dialect/include/Dialect/ProtonGPU/IR/ProtonGPUAttrDefs.h.inc"
15+
16+
#define GET_OP_CLASSES
17+
#include "proton/dialect/include/Dialect/ProtonGPU/IR/Ops.h.inc"
18+
19+
#endif // DIALECT_PROTON_GPU_IR_DIALECT_H_
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef PROTON_GPU_ATTR_DEFS
2+
#define PROTON_GPU_ATTR_DEFS
3+
4+
include "mlir/IR/EnumAttr.td"
5+
6+
def GranularityAttr : I32EnumAttr<
7+
"Granularity", "",
8+
[
9+
I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">,
10+
I32EnumAttrCase<"WARP", 1, "warp">,
11+
]> {
12+
let cppNamespace = "::mlir::triton::proton::proton_gpu";
13+
}
14+
15+
#endif // PROTON_GPU_ATTR_DEFS

0 commit comments

Comments
 (0)