
Commit b0ebcfc

[AMD] Adjust local_store and global_load ordering (#5254)
This commit adjusts the ordering of local stores and global loads so that a local store is placed ahead of a global load when the two are not in the same pipeline stage. This should help GEMM kernel performance.
1 parent 134b3eb commit b0ebcfc

2 files changed (+21, -21 lines)

test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 18 additions & 18 deletions
@@ -270,30 +270,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 
 // CHECK-LABEL: tt.func @matmul_loop_mb
 // CHECK: %{{.*}}:8 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}}, %[[ARG8:.*]] = %{{.*}}, %[[ARG9:.*]] = %{{.*}}, %[[ARG10:.*]] = %{{.*}}, %[[ARG11:.*]] = %{{.*}}, %[[ARG12:.*]] = %{{.*}}, %[[ARG13:.*]] = %{{.*}})
-// Stage 0
-// CHECK: %[[ADDPTR_28:.*]] = tt.addptr %[[ARG6]], %{{.*}}
-// CHECK: %[[MULI_29:.*]] = arith.muli %{{.*}}, %{{.*}}
-// CHECK: %[[SUBI_30:.*]] = arith.subi %{{.*}}, %[[MULI_29]]
-// CHECK: %[[CMPI_31:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_30]]
-// CHECK: %[[SPLAT_32:.*]] = tt.splat %[[CMPI_31]]
-// CHECK: %[[LOAD_33:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_32]]
-// CHECK: %[[ADDPTR_34:.*]] = tt.addptr %[[ARG7]], %{{.*}}
-// CHECK: %[[SPLAT_35:.*]] = tt.splat %[[CMPI_31]]
-// CHECK: %[[LOAD_36:.*]] = tt.load %[[ADDPTR_34]], %[[SPLAT_35]]
 // Stage 1
-// CHECK: %[[ADDI_37:.*]] = arith.addi %[[ARG9]], %{{.*}}
-// CHECK: %[[CMPI_38:.*]] = arith.cmpi slt, %[[ADDI_37]], %{{.*}}
-// CHECK: %[[SELECT_39:.*]] = arith.select %[[CMPI_38]], %[[ADDI_37]], %{{.*}}
-// CHECK: %[[MEMDESC_SUBVIEW_40:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
-// CHECK: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_40]]
-// CHECK: %[[MEMDESC_SUBVIEW_41:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
-// CHECK: ttg.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_41]]
+// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG9]], %{{.*}}
+// CHECK: %[[CMPI_29:.*]] = arith.cmpi slt, %[[ADDI_28]], %{{.*}}
+// CHECK: %[[SELECT_30:.*]] = arith.select %[[CMPI_29]], %[[ADDI_28]], %{{.*}}
+// CHECK: %[[MEMDESC_SUBVIEW_31:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}]
+// CHECK: ttg.local_store %[[ARG12]], %[[MEMDESC_SUBVIEW_31]]
+// CHECK: %[[MEMDESC_SUBVIEW_32:.*]] = ttg.memdesc_subview %{{.*}}[%[[SELECT_30]], %{{.*}}, %{{.*}}]
+// CHECK: ttg.local_store %[[ARG13]], %[[MEMDESC_SUBVIEW_32]]
+// Stage 1
+// CHECK: %[[ADDPTR_33:.*]] = tt.addptr %[[ARG6]], %{{.*}}
+// CHECK: %[[MULI_34:.*]] = arith.muli %{{.*}}, %{{.*}}
+// CHECK: %[[SUBI_35:.*]] = arith.subi %{{.*}}, %[[MULI_34]]
+// CHECK: %[[CMPI_36:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_35]]
+// CHECK: %[[SPLAT_37:.*]] = tt.splat %[[CMPI_36]]
+// CHECK: %[[LOAD_38:.*]] = tt.load %[[ADDPTR_33]], %[[SPLAT_37]]
+// CHECK: %[[ADDPTR_39:.*]] = tt.addptr %[[ARG7]], %{{.*}}
+// CHECK: %[[SPLAT_40:.*]] = tt.splat %[[CMPI_36]]
+// CHECK: %[[LOAD_41:.*]] = tt.load %[[ADDPTR_39]], %[[SPLAT_40]]
 // Stage 2
 // CHECK: %[[LOCAL_LOAD_42:.*]] = ttg.local_load %[[ARG10]]
 // CHECK: %[[LOCAL_LOAD_43:.*]] = ttg.local_load %[[ARG11]]
 // CHECK: %[[MULF_44:.*]] = arith.mulf %[[LOCAL_LOAD_43]], %{{.*}}
 // CHECK: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_42]], %[[MULF_44]], %[[ARG8]]
-// CHECK: scf.yield %[[ADDPTR_28]], %[[ADDPTR_34]], %[[DOT_45]], %[[SELECT_39]], %[[MEMDESC_SUBVIEW_40]], %[[MEMDESC_SUBVIEW_41]], %[[LOAD_33]], %[[LOAD_36]]
+// CHECK: scf.yield %[[ADDPTR_33]], %[[ADDPTR_39]], %[[DOT_45]], %[[SELECT_30]], %[[MEMDESC_SUBVIEW_31]], %[[MEMDESC_SUBVIEW_32]], %[[LOAD_38]], %[[LOAD_41]]
 // CHECK: }
 
 tt.func @matmul_loop_mb(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> {

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 3 additions & 3 deletions
@@ -216,12 +216,12 @@ static void moveUpTranspose(triton::FuncOp funcOp) {
 // Schedule global load and local store ops for better GEMM performance.
 static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
   SmallVector<Operation *> moveOps;
-  // Move global loads early to prefetch. This may increase register pressure
-  // but it enables issuing global loads early.
-  funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
   // Move local_stores early if dependence distance greater than one iteration.
   // Best perf on GEMM when these precede global loads.
   funcOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
+  // Move global loads early to prefetch. This may increase register pressure
+  // but it enables issuing global loads early.
+  funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
 
   for (auto op : llvm::reverse(moveOps)) {
     // Gather use-def chain in block.
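
For intuition, the following is a minimal, self-contained sketch (not the pass's actual code; all op names and helpers are made up) of why swapping the two walk calls changes the final instruction order. It assumes, per the loop above, that the collected ops are processed in reverse and that each one is hoisted to the earliest position its dependencies allow; under that model, ops collected earlier end up placed earlier in the block.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Block = std::vector<std::string>;
using DepMap = std::map<std::string, std::vector<std::string>>;

// Hoist `op` to the earliest slot that still keeps it after all of its deps.
static void hoistEarly(Block &block, const std::string &op, const DepMap &deps) {
  auto it = std::find(block.begin(), block.end(), op);
  if (it == block.end())
    return;
  block.erase(it);
  std::size_t insertPos = 0;
  auto d = deps.find(op);
  if (d != deps.end()) {
    for (const std::string &dep : d->second) {
      auto pos = std::find(block.begin(), block.end(), dep);
      if (pos != block.end())
        insertPos = std::max(insertPos,
                             static_cast<std::size_t>(pos - block.begin()) + 1);
    }
  }
  block.insert(block.begin() + insertPos, op);
}

int main() {
  // Loop body as produced before scheduling: global loads, then local stores,
  // then local loads feeding the dot.
  Block block = {"global_load_a", "global_load_b", "local_store_a",
                 "local_store_b", "local_load_a",  "local_load_b", "dot"};

  // The local stores write data loaded in a *previous* iteration (loop
  // iter_args), so they do not depend on this iteration's global loads.
  DepMap deps = {{"local_load_a", {"local_store_a"}},
                 {"local_load_b", {"local_store_b"}},
                 {"dot", {"local_load_a", "local_load_b"}}};

  // New collection order: local stores first, then global loads.
  std::vector<std::string> moveOps = {"local_store_a", "local_store_b",
                                      "global_load_a", "global_load_b"};

  // Process in reverse and hoist each op; ops collected earlier are hoisted
  // last and therefore end up highest in the block.
  for (auto it = moveOps.rbegin(); it != moveOps.rend(); ++it)
    hoistEarly(block, *it, deps);

  // Prints: local_store_a, local_store_b, global_load_a, global_load_b, ...
  for (const std::string &op : block)
    std::cout << op << "\n";
  return 0;
}

With the previous collection order (global loads walked first), the same procedure leaves the loads ahead of the local stores, which is exactly the ordering the updated .mlir CHECK lines no longer expect.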
