diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td
index a91b7951af00..a79ab29e8001 100644
--- a/include/triton/Dialect/Triton/IR/TritonDialect.td
+++ b/include/triton/Dialect/Triton/IR/TritonDialect.td
@@ -34,8 +34,20 @@ def Triton_Dialect : Dialect {
 
   let extraClassDeclaration = [{
     void registerTypes();
+
+    static TritonDialect *getLoaded(MLIRContext *ctx) {
+      return ctx->getLoadedDialect<TritonDialect>();
+    }
+    static TritonDialect *getLoaded(Operation *op) {
+      return getLoaded(op->getContext());
+    }
   }];
 
+  let discardableAttrs = (ins
+    "::mlir::IntegerAttr":$num_stages,
+    "::mlir::IntegerAttr":$latency
+  );
+
   let hasConstantMaterializer = 1;
   let useDefaultTypePrinterParser = 1;
   let usePropertiesForAttributes = 1;
diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
index 60ff5e005a38..29b8d026da58 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
@@ -18,7 +18,6 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
 static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
-static const char *kLatencyAttrName = "tt.latency";
 
 bool loopHasDistGreaterThanOne(scf::ForOp forOp);
 bool isOuterLoop(scf::ForOp forOp);
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
index 3da99c245c62..ce00ecbdfa9d 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
@@ -37,8 +37,9 @@ bool preCondition(scf::ForOp forOp) {
 }
 
 bool hasLatenciesAssigned(scf::ForOp forOp) {
+  auto helper = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper();
   for (auto &op : forOp.getBody()->without_terminator()) {
-    if (op.hasAttr("tt_latency"))
+    if (helper.getAttr(&op))
       return true;
   }
   return false;
@@ -46,9 +47,10 @@ bool hasLatenciesAssigned(scf::ForOp forOp) {
 
 void assignUserProvidedLatencies(scf::ForOp forOp,
                                  DenseMap<Operation *, int> &opLatency) {
+  auto helper = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper();
   for (auto &op : forOp.getBody()->without_terminator()) {
-    if (auto latencyAttr = op.getAttr("tt_latency")) {
-      opLatency[&op] = mlir::cast<IntegerAttr>(latencyAttr).getInt();
+    if (auto latencyAttr = helper.getAttr(&op)) {
+      opLatency[&op] = latencyAttr.getInt();
     }
   }
 }
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
index 148ed58e4ba3..911b1104c16c 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
@@ -248,19 +248,20 @@ int mlir::triton::getCopyVecBytes(RankedTensorType registerTy,
 
 void mlir::triton::serializeLatencies(ModuleOp module,
                                       DenseMap<Operation *, int> &opLatency) {
+  auto helper = TritonDialect::getLoaded(module)->getLatencyAttrHelper();
+  auto builder = Builder(module);
   for (auto &[op, latency] : opLatency) {
-    op->setAttr(
-        kLatencyAttrName,
-        IntegerAttr::get(IntegerType::get(module.getContext(), 32), latency));
+    helper.setAttr(op, builder.getI32IntegerAttr(latency));
   }
 }
 
 DenseMap<Operation *, int> mlir::triton::deserializeLatencies(Operation *op) {
+  auto helper = TritonDialect::getLoaded(op)->getLatencyAttrHelper();
   DenseMap<Operation *, int> opLatency;
   op->walk([&](Operation *op) {
-    if (op->hasAttr(kLatencyAttrName)) {
-      opLatency[op] = op->getAttrOfType<IntegerAttr>(kLatencyAttrName).getInt();
-      op->removeAttr(kLatencyAttrName);
+    if (auto attr = helper.getAttr(op)) {
+      opLatency[op] = attr.getInt();
+      helper.removeAttr(op);
     }
   });
   return opLatency;
@@ -408,9 +409,8 @@ int mlir::triton::getNumStagesOrDefault(scf::ForOp forOp,
                                         int defaultNumStages) {
   // Use the attribute attached to the loop if it exists otherwise use the
   // global control.
-  if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName))
-    return defaultNumStages;
-  return mlir::cast<IntegerAttr>(
-             forOp->getAttr(mlir::triton::kNumStagesAttrName))
-      .getInt();
+  auto helper = TritonDialect::getLoaded(forOp)->getNumStagesAttrHelper();
+  if (auto attr = helper.getAttr(forOp))
+    return attr.getInt();
+  return defaultNumStages;
 }
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp
index a95688aa0e2f..4397595fc772 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp
@@ -15,8 +15,6 @@ namespace gpu {
 #define GEN_PASS_DEF_TRITONGPUTESTPIPELINESCHEDULELOOP
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
 
-static const char *kLatencyAttrName = "tt.latency";
-
 struct TestPipelineScheduleLoop
     : public impl::TritonGPUTestPipelineScheduleLoopBase<
           TestPipelineScheduleLoop> {
diff --git a/test/TritonGPU/loop-pipeline-async-latencies.mlir b/test/TritonGPU/loop-pipeline-async-latencies.mlir
index 06e4d053bdd2..56fcc4c75bdc 100644
--- a/test/TritonGPU/loop-pipeline-async-latencies.mlir
+++ b/test/TritonGPU/loop-pipeline-async-latencies.mlir
@@ -101,9 +101,9 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr<i8> {tt.nv_tma_de
 
     // CHECK: [[RHS_MBAR:%.*]] = ttg.memdesc_subview [[RHS_BARS]][[[RHS_BUF_IDX]]]
     // CHECK-NEXT: ttng.wait_barrier [[RHS_MBAR]], [[RHS_PHASE]]
-    %4 = tt.descriptor_load %1[%c0_i32, %arg6] {tt_latency = 1 : i32} : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16, #blocked>
+    %4 = tt.descriptor_load %1[%c0_i32, %arg6] {tt.latency = 1 : i32} : !tt.tensordesc<tensor<128x64xf16>> -> tensor<128x64xf16, #blocked>
     %5 = ttg.local_alloc %4 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
-    %6 = tt.descriptor_load %2[%c0_i32, %arg6] {tt_latency = 3 : i32} : !tt.tensordesc<tensor<256x64xf16>> -> tensor<256x64xf16, #blocked>
+    %6 = tt.descriptor_load %2[%c0_i32, %arg6] {tt.latency = 3 : i32} : !tt.tensordesc<tensor<256x64xf16>> -> tensor<256x64xf16, #blocked>
     %7 = ttg.local_alloc %6 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem>
     %8 = ttg.memdesc_trans %7 {order = array<i32: 1, 0>} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem>
     %9 = ttng.warp_group_dot %5, %8, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma>
diff --git a/test/TritonGPU/loop-schedule.mlir b/test/TritonGPU/loop-schedule.mlir
index 406b9250836c..282451570f0c 100644
--- a/test/TritonGPU/loop-schedule.mlir
+++ b/test/TritonGPU/loop-schedule.mlir
@@ -168,7 +168,7 @@ tt.func @prologue_backward_slice(%ub: i32, %cond: i1) {
     // CHECK: op.with_region
     "op.with_region"() ({
       "use"(%1) : (i32) -> ()
-    }) {tt_latency = 2 : i32} : () -> ()
+    }) {tt.latency = 2 : i32} : () -> ()
     // CHECK: loop.cluster = 1 : i32, loop.stage = 0 : i32
   } {tt.num_stages = 3 : i32}
 }
@@ -186,7 +186,7 @@ tt.func @epilogue_forward_slice(%ub: i32, %cond: i1) {
   // CHECK: scf.for
   scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
     // CHECK: "latency.op"() {loop.cluster = 3 : i32, loop.stage = 0 : i32
-    %0 = "latency.op"() {tt_latency = 2 : i32} : () -> i32
+    %0 = "latency.op"() {tt.latency = 2 : i32} : () -> i32
     // CHECK: scf.if
     %1 = scf.if %cond -> i32 {
       scf.yield %0 : i32
@@ -219,7 +219,7 @@ tt.func @prologue_latency(%ub: i32, %cond: i1) {
       scf.yield %0 : i32
     } else {
      scf.yield %c0_i32 : i32
-    } {tt_latency = 2 : i32}
+    } {tt.latency = 2 : i32}
     // CHECK: loop.cluster = 0 : i32, loop.stage = 0 : i32
   } {tt.num_stages = 3 : i32}
 }