Skip to content

Commit 9220772

Browse files
authored
[PROTON-DEV] Refactor frontend interface (#5825)
1 parent df7a403 commit 9220772

File tree

15 files changed

+208
-67
lines changed

15 files changed

+208
-67
lines changed

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -631,14 +631,15 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
631631
}
632632
// Proton patterns
633633
// NOTE: Because Proton's inputs are scalars and not tensors this conversion
634-
// isn't strictly nessessary however you could envision a case where we pass in
634+
// isn't strictly necessary; however, you could envision a case where we pass
635635
// tensors in for Triton object specific tracing operations in which case we
636636
// would need to fill in the OpConversionPattern
637637
void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
638638
RewritePatternSet &patterns) {
639639
MLIRContext *context = patterns.getContext();
640-
patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
641-
context);
640+
patterns.add<GenericOpPattern<triton::proton::RecordOp>,
641+
GenericOpPattern<triton::proton::InitScopeOp>>(typeConverter,
642+
context);
642643
}
643644
//
644645
// SCF patterns

python/src/ir.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,8 +1735,12 @@ void init_triton_ir(py::module &&m) {
17351735
})
17361736
// Proton Ops
17371737
.def("create_proton_record",
1738-
[](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void {
1739-
self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
1738+
[](TritonOpBuilder &self, bool isStart, Value &scopeId) -> void {
1739+
self.create<mlir::triton::proton::RecordOp>(isStart, scopeId);
1740+
})
1741+
.def("create_proton_init_scope",
1742+
[](TritonOpBuilder &self, const std::string &name) -> Value {
1743+
return self.create<mlir::triton::proton::InitScopeOp>(name);
17401744
});
17411745

17421746
py::class_<PassManager>(m, "pass_manager", py::module_local())

python/triton/compiler/compiler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ class CompiledKernel:
355355
# TODO: move out of this namespace since it's a runtime thing
356356
launch_enter_hook = None
357357
launch_exit_hook = None
358+
init_handle_hook = None
358359

359360
def __init__(self, src, metadata_group, hash):
360361
from collections import namedtuple
@@ -403,6 +404,8 @@ def _init_handles(self):
403404
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
404405
self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary(
405406
self.name, self.kernel, self.metadata.shared, device)
407+
if self.init_handle_hook is not None:
408+
self.init_handle_hook(self.module, self.function, self.metadata_path)
406409

407410
def __getattribute__(self, name):
408411
if name == 'run':
@@ -415,11 +418,7 @@ def launch_metadata(self, grid, stream, *args):
415418
ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
416419
if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
417420
return ret
418-
arg_dict = {}
419-
arg_idx = 0
420-
for i, arg_name in enumerate(self.src.fn.arg_names):
421-
arg_dict[arg_name] = args[arg_idx]
422-
arg_idx += 1
421+
arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
423422
ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
424423
return ret
425424

test/Proton/ops.mlir

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
1-
// RUN: triton-opt --split-input-file %s -cse -canonicalize --proton-lowering-pass | FileCheck %s
1+
// RUN: triton-opt --split-input-file %s | FileCheck %s
22

33
module {
4+
// CHECK-LABEL: proton_init_scope
5+
tt.func @proton_init_scope() {
6+
// CHECK: proton.init_scope "name0" : i32
7+
// CHECK-NEXT: tt.return
8+
%0 = proton.init_scope "name0" : i32
9+
tt.return
10+
}
11+
// CHECK-LABEL: proton_record
412
tt.func @proton_record() {
5-
// CHECK: proton.record() {isStart = true, regionId = 1 : i32}
6-
// CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32}
13+
// CHECK: proton.record start %0
14+
// CHECK: proton.record end %0
715
// CHECK-NEXT: tt.return
8-
proton.record() {isStart = true, regionId = 1 : i32}
9-
proton.record() {isStart = false, regionId = 1 : i32}
16+
%0 = proton.init_scope "name0" : i32
17+
proton.record start %0
18+
proton.record end %0
1019
tt.return
1120
}
1221
} // end module

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ struct ConvertTritonAMDGPUToLLVM
231231
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
232232
targetInfo, commonBenefit);
233233

234+
mlir::triton::proton::populateInitScopeOpToLLVMPattern(
235+
typeConverter, patterns, commonBenefit);
234236
mlir::triton::proton::populateRecordOpToLLVMPattern(
235237
typeConverter, patterns, targetInfo, commonBenefit);
236238

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ struct ConvertTritonGPUToLLVM
153153
targetInfo, benefit);
154154
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
155155
targetInfo, benefit);
156+
mlir::triton::proton::populateInitScopeOpToLLVMPattern(typeConverter,
157+
patterns, benefit);
156158
mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns,
157159
targetInfo, benefit);
158160
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,

third_party/proton/csrc/include/Context/Shadow.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ class ShadowContextSource : public ContextSource, public ScopeInterface {
3737

3838
} // namespace proton
3939

40-
#endif // PROTON_CONTEXT_CONTEXT_H_
40+
#endif // PROTON_CONTEXT_SHADOW_H_

third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,61 +13,80 @@ include "ProtonDialect.td"
1313
include "ProtonAttrDefs.td"
1414

1515
class PT_Op<string mnemonic, list<Trait> traits = []> :
16-
Op<Proton_Dialect, mnemonic, !listconcat(traits, [])> {
16+
Op<Proton_Dialect, mnemonic, !listconcat(traits, [])> {
1717
}
1818

1919
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
2020
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;
2121

2222
// Proton profiling metric.
2323
def MetricAttr : I32EnumAttr<
24-
"Metric", "",
25-
[
26-
I32EnumAttrCase<"CYCLE", 0, "cycle">,
27-
]> {
28-
let cppNamespace = "::mlir::triton::proton";
24+
"Metric", "",
25+
[
26+
I32EnumAttrCase<"CYCLE", 0, "cycle">,
27+
]> {
28+
let cppNamespace = "::mlir::triton::proton";
2929
}
3030

3131
// Proton profiling granularity.
3232
def GranularityAttr : I32EnumAttr<
33-
"Granularity", "",
34-
[
35-
I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">,
36-
I32EnumAttrCase<"WARP", 1, "warp">,
37-
]> {
38-
let cppNamespace = "::mlir::triton::proton";
33+
"Granularity", "",
34+
[
35+
I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">,
36+
I32EnumAttrCase<"WARP", 1, "warp">,
37+
]> {
38+
let cppNamespace = "::mlir::triton::proton";
39+
}
40+
41+
def PT_InitScopeOp : PT_Op<"init_scope", [Pure]> {
42+
let summary = "Initialize a scope";
43+
44+
let description = [{
45+
This operation initializes a scope with the given name and returns a unique id for the scope.
46+
47+
Example:
48+
49+
```mlir
50+
%scope0 = proton.init_scope "name0" : i32
51+
```
52+
}];
53+
54+
let arguments = (
55+
ins StrAttr: $scopeName
56+
);
57+
let results = (outs I32 : $scopeId);
58+
59+
let assemblyFormat = "$scopeName attr-dict `:` type($scopeId) ";
60+
// hasVerifier = 1; verify that (1) each scope is used twice, and (2) the ops using it are record-like ops
3961
}
4062

4163
def PT_RecordOp : PT_Op<"record", [
42-
MemoryEffects<[MemRead<DefaultResource>]>,
43-
MemoryEffects<[MemWrite<DefaultResource>]>
44-
]> {
45-
let summary = "Record a GPU metric event";
64+
MemoryEffects<[MemRead<DefaultResource>, MemWrite<DefaultResource>]>
65+
]> {
66+
let summary = "Record an event";
4667

4768
let description = [{
48-
The operator records GPU event of a particular metric.
49-
Essentially a marker with a region id.
69+
This operation records events of a particular metric.
5070

5171
Example:
5272

5373
```mlir
54-
proton.record() {isStart = true, regionId = 4 : i32}
74+
proton.record start %scope0
5575
...
56-
proton.record() {isStart = false, regionId = 4 : i32}
76+
proton.record end %scope0
5777
```
5878
}];
5979
let arguments = (
60-
ins BoolAttr: $isStart,
61-
ConfinedAttr<I32Attr, [IntNonNegative]>:$regionId
80+
ins UnitAttr: $isStart,
81+
I32: $scopeId
6282
);
6383

64-
let assemblyFormat = " `(` operands `)` attr-dict";
84+
let assemblyFormat = "(`start` $isStart^):(`end`)? $scopeId attr-dict";
6585
}
6686

6787
def PT_CircularRecordOp : PT_Op<"circular_record", [
68-
MemoryEffects<[MemRead<DefaultResource>]>,
69-
MemoryEffects<[MemWrite<DefaultResource>]>
70-
]> {
88+
MemoryEffects<[MemRead<DefaultResource>, MemWrite<DefaultResource>]>
89+
]> {
7190
let summary = "Record a GPU metric event into a circular buffer";
7291

7392
let description = [{
@@ -76,16 +95,19 @@ def PT_CircularRecordOp : PT_Op<"circular_record", [
7695
get dropped if the `data` is full.
7796
}];
7897
let arguments = (
79-
ins TTG_MemDescType:$data,
98+
ins UnitAttr: $isStart,
99+
I32: $scopeId,
100+
TTG_MemDescType:$data,
80101
TT_PtrLike :$indexPtr,
81-
BoolAttr: $isStart,
82-
ConfinedAttr<I32Attr, [IntNonNegative]>:$regionId,
83102
DefaultValuedAttr<MetricAttr, "Metric::CYCLE">:$metric,
84103
DefaultValuedAttr<GranularityAttr, "Granularity::WARPGROUP">:$granularity
85104
);
86105
let hasVerifier = 1;
87106

88-
let assemblyFormat = [{$data `,` $indexPtr attr-dict `:` qualified(type($data)) `,` type($indexPtr)}];
107+
let assemblyFormat = [{
108+
(`start` $isStart^):(`end`)? $scopeId `,` $data `,` $indexPtr attr-dict `:`
109+
qualified(type($data)) `,` type($indexPtr)
110+
}];
89111
}
90112

91113
def PT_FinalizeOp : PT_Op<"finalize", [
Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
2-
#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
1+
#ifndef TRITON_PROTON_TO_LLVM_PATTERN_TRITON_PROTON_OP_TO_LLVM_H
2+
#define TRITON_PROTON_TO_LLVM_PATTERN_TRITON_PROTON_OP_TO_LLVM_H
33

44
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
55

@@ -10,7 +10,12 @@ void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
1010
RewritePatternSet &patterns,
1111
const TargetInfoBase &targetInfo,
1212
PatternBenefit benefit);
13+
14+
void populateInitScopeOpToLLVMPattern(LLVMTypeConverter &typeConverter,
15+
RewritePatternSet &patterns,
16+
PatternBenefit benefit);
17+
1318
} // namespace proton
1419
} // namespace mlir::triton
1520

16-
#endif
21+
#endif // TRITON_PROTON_TO_LLVM_PATTERN_TRITON_PROTON_OP_TO_LLVM_H

third_party/proton/dialect/lib/TritonProtonToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_triton_library(TritonProtonToLLVM
2+
InitScopeOpToLLVM.cpp
23
RecordOpToLLVM.cpp
34
ProtonLoweringPass.cpp
45

0 commit comments

Comments
 (0)