diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir
index 5dfd0f2a5f4c..1b8dc10d938b 100644
--- a/test/TritonGPU/amd/amd-reorder-instructions.mlir
+++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -499,3 +499,40 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: dont_hoist_scf_ops
+  // Make sure we don't hoist scf ops above their dependencies.
+  tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>,
+    %base: tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>,
+    %p1: tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    // CHECK: scf.for
+    %54 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg = %init) -> (tensor<256x128xf32, #mfma>) : i32 {
+      // CHECK: arith.addi
+      %f = arith.addi %arg21, %c128_i32 : i32
+      // CHECK: scf.if
+      // CHECK: tt.load
+      %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> {
+        %t = tt.splat %f : i32 -> tensor<256x128xi32>
+        %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32>
+        scf.yield %padd : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      } else {
+        scf.yield %base : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      }
+      %l = tt.load %p0 : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %r = tt.load %p1 : tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      scf.yield %acc : tensor<256x128xf32, #mfma>
+    }
+    tt.return %54 : tensor<256x128xf32, #mfma>
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
index f55ab7855440..0837f16dcf7c 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
@@ -227,6 +227,7 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
     // Gather use-def chain in block.
     Block *block = op->getBlock();
     bool leadsToLoad = false;
+    bool dontReorder = false;
     SetVector<Operation *> backwardSet;
 
     BackwardSliceOptions options;
@@ -236,6 +237,13 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
       Block *defBlock = defOp->getBlock();
       if (!block->findAncestorOpInBlock(*defOp))
         return false;
+      // Don't hoist control flow, as we don't track backward slices of ops
+      // within their regions.
+      if (isa<scf::IfOp, scf::ForOp, scf::WhileOp>(defOp)) {
+        dontReorder = true;
+        return false;
+      }
+
       // Check for a `load` dependent path.
       leadsToLoad |= isa<triton::LoadOp>(defOp);
       // Only move ops residing in the same block.
@@ -244,6 +252,9 @@
     mlir::getBackwardSlice(op, &backwardSet, options);
     backwardSet.insert(op);
 
+    // If we found ops in the slice we don't want to hoist, skip this op.
+    if (dontReorder)
+      continue;
     // Don't move a local_store if its source is a load from
     // the same iteration.
     if (isa<ttg::LocalStoreOp>(op) && leadsToLoad)
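
For reference, the guard the patch adds boils down to the pattern sketched below: the `BackwardSliceOptions::filter` callback records whether the backward slice runs into a control-flow op and, if so, the caller leaves the op where it is. This is a minimal standalone sketch, not the pass itself; the helper name `backwardSliceCrossesControlFlow` is hypothetical, and the `scf::IfOp/ForOp/WhileOp` list mirrors the reconstruction above rather than a verified upstream op set.

```cpp
// Sketch only: illustrates the "dontReorder" pattern from the patch above.
// The helper name is hypothetical; the guarded op list is an assumption.
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Returns true if op's backward slice (restricted to its block) reaches a
// control-flow op whose regions the slice does not descend into.
static bool backwardSliceCrossesControlFlow(Operation *op) {
  Block *block = op->getBlock();
  bool crossesControlFlow = false;

  BackwardSliceOptions options;
  options.omitBlockArguments = true;
  options.inclusive = false;
  options.filter = [&](Operation *defOp) -> bool {
    // Ignore defs that are not nested under op's block at all.
    if (!block->findAncestorOpInBlock(*defOp))
      return false;
    // Stop at control flow: dependencies defined inside its regions are not
    // visible to the slice, so hoisting above it would be unsafe.
    if (isa<scf::IfOp, scf::ForOp, scf::WhileOp>(defOp)) {
      crossesControlFlow = true;
      return false;
    }
    // Otherwise only follow defs that live in the same block.
    return defOp->getBlock() == block;
  };

  SetVector<Operation *> backwardSet;
  getBackwardSlice(op, &backwardSet, options);
  return crossesControlFlow;
}
```

Bailing out on the whole op, rather than trying to walk into the regions of the control-flow op, keeps the pass conservative: a load whose address comes out of an `scf.if` simply stays put, which is exactly what the new `dont_hoist_scf_ops` lit test checks.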