Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def Triton_Dialect : Dialect {
void registerTypes();
}];

let discardableAttrs = (ins
"::mlir::IntegerAttr":$num_stages,
"::mlir::IntegerAttr":$latency
);

let hasConstantMaterializer = 1;
let useDefaultTypePrinterParser = 1;
let usePropertiesForAttributes = 1;
Expand Down
12 changes: 9 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,24 @@ bool preCondition(scf::ForOp forOp) {
}

bool hasLatenciesAssigned(scf::ForOp forOp) {
auto helper = forOp.getContext()
->getLoadedDialect<TritonDialect>()
->getLatencyAttrHelper();
for (auto &op : forOp.getBody()->without_terminator()) {
if (op.hasAttr("tt_latency"))
if (helper.getAttr(&op))
return true;
}
return false;
}

void assignUserProvidedLatencies(scf::ForOp forOp,
DenseMap<Operation *, int> &opLatency) {
auto helper = forOp.getContext()
->getLoadedDialect<TritonDialect>()
->getLatencyAttrHelper();
for (auto &op : forOp.getBody()->without_terminator()) {
if (auto latencyAttr = op.getAttr("tt_latency")) {
opLatency[&op] = mlir::cast<IntegerAttr>(latencyAttr).getInt();
if (auto latencyAttr = helper.getAttr(&op)) {
opLatency[&op] = latencyAttr.getInt();
}
}
}
Expand Down
28 changes: 17 additions & 11 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,19 +248,24 @@ int mlir::triton::getCopyVecBytes(RankedTensorType registerTy,

void mlir::triton::serializeLatencies(ModuleOp module,
DenseMap<Operation *, int> &opLatency) {
auto helper = module.getContext()
->getLoadedDialect<TritonDialect>()
->getLatencyAttrHelper();
auto builder = Builder(module);
for (auto &[op, latency] : opLatency) {
op->setAttr(
kLatencyAttrName,
IntegerAttr::get(IntegerType::get(module.getContext(), 32), latency));
helper.setAttr(op, builder.getI32IntegerAttr(latency));
}
}

DenseMap<Operation *, int> mlir::triton::deserializeLatencies(Operation *op) {
auto helper = op->getContext()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: can you add a shorthand for this in Triton/Dialect.h?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! added.

Copy link
Contributor Author

@sjw36 sjw36 Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

posted these upstream here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well.. I guess Medhi and River are not big fans..

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They tend to be picky :P

->getLoadedDialect<TritonDialect>()
->getLatencyAttrHelper();
DenseMap<Operation *, int> opLatency;
op->walk([&](Operation *op) {
if (op->hasAttr(kLatencyAttrName)) {
opLatency[op] = op->getAttrOfType<IntegerAttr>(kLatencyAttrName).getInt();
op->removeAttr(kLatencyAttrName);
if (auto attr = helper.getAttr(op)) {
opLatency[op] = attr.getInt();
helper.removeAttr(op);
}
});
return opLatency;
Expand Down Expand Up @@ -408,9 +413,10 @@ int mlir::triton::getNumStagesOrDefault(scf::ForOp forOp,
int defaultNumStages) {
// Use the attribute attached to the loop if it exists otherwise use the
// global control.
if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName))
return defaultNumStages;
return mlir::cast<IntegerAttr>(
forOp->getAttr(mlir::triton::kNumStagesAttrName))
.getInt();
auto helper = forOp.getContext()
->getLoadedDialect<TritonDialect>()
->getNumStagesAttrHelper();
if (auto attr = helper.getAttr(forOp))
return attr.getInt();
return defaultNumStages;
}
4 changes: 2 additions & 2 deletions test/TritonGPU/loop-pipeline-async-latencies.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_de
// CHECK: [[RHS_MBAR:%.*]] = ttg.memdesc_subview [[RHS_BARS]][[[RHS_BUF_IDX]]]
// CHECK-NEXT: ttng.wait_barrier [[RHS_MBAR]], [[RHS_PHASE]]

%4 = tt.descriptor_load %1[%c0_i32, %arg6] {tt_latency = 1 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked>
%4 = tt.descriptor_load %1[%c0_i32, %arg6] {tt.latency = 1 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked>
%5 = ttg.local_alloc %4 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
%6 = tt.descriptor_load %2[%c0_i32, %arg6] {tt_latency = 3 : i32} : !tt.tensordesc<tensor<256x64xf16, #shared>> -> tensor<256x64xf16, #blocked>
%6 = tt.descriptor_load %2[%c0_i32, %arg6] {tt.latency = 3 : i32} : !tt.tensordesc<tensor<256x64xf16, #shared>> -> tensor<256x64xf16, #blocked>
%7 = ttg.local_alloc %6 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem>
%8 = ttg.memdesc_trans %7 {order = array<i32: 1, 0>} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem>
%9 = ttng.warp_group_dot %5, %8, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma>
Expand Down
6 changes: 3 additions & 3 deletions test/TritonGPU/loop-schedule.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ tt.func @prologue_backward_slice(%ub: i32, %cond: i1) {
// CHECK: op.with_region
"op.with_region"() ({
"use"(%1) : (i32) -> ()
}) {tt_latency = 2 : i32} : () -> ()
}) {tt.latency = 2 : i32} : () -> ()
// CHECK: loop.cluster = 1 : i32, loop.stage = 0 : i32

} {tt.num_stages = 3 : i32}
Expand All @@ -186,7 +186,7 @@ tt.func @epilogue_forward_slice(%ub: i32, %cond: i1) {
// CHECK: scf.for
scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
// CHECK: "latency.op"() {loop.cluster = 3 : i32, loop.stage = 0 : i32
%0 = "latency.op"() {tt_latency = 2 : i32} : () -> i32
%0 = "latency.op"() {tt.latency = 2 : i32} : () -> i32
// CHECK: scf.if
%1 = scf.if %cond -> i32 {
scf.yield %0 : i32
Expand Down Expand Up @@ -219,7 +219,7 @@ tt.func @prologue_latency(%ub: i32, %cond: i1) {
scf.yield %0 : i32
} else {
scf.yield %c0_i32 : i32
} {tt_latency = 2 : i32}
} {tt.latency = 2 : i32}
// CHECK: loop.cluster = 0 : i32, loop.stage = 0 : i32

} {tt.num_stages = 3 : i32}
Expand Down