From e9f34115a5166c42be0a25a53f10d0ea55b37aa3 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Wed, 20 Nov 2024 15:35:08 -0800
Subject: [PATCH] [AMD] Prevent wrong reordering of scf operations

The pass was reordering scf.if operations without checking the extra
dependencies coming from their regions. For now, just prevent this case,
although this part of the code might still be fragile.
---
 .../amd/amd-reorder-instructions.mlir         | 37 +++++++++++++++++++
 .../ReorderInstructions.cpp                   | 11 ++++++
 2 files changed, 48 insertions(+)

diff --git a/test/TritonGPU/amd/amd-reorder-instructions.mlir b/test/TritonGPU/amd/amd-reorder-instructions.mlir
index 5dfd0f2a5f4c..1b8dc10d938b 100644
--- a/test/TritonGPU/amd/amd-reorder-instructions.mlir
+++ b/test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -499,3 +499,40 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+
+// -----
+
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: dont_hoist_scf_ops
+  // Make sure we don't hoist scf ops above their dependencies.
+  tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>,
+    %base: tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>,
+    %p1: tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    // CHECK: scf.for
+    %54 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg = %init) -> (tensor<256x128xf32, #mfma>) : i32 {
+      // CHECK: arith.addi
+      %f = arith.addi %arg21, %c128_i32 : i32
+      // CHECK: scf.if
+      // CHECK: tt.load
+      %p0 = scf.if %i1 -> tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> {
+        %t = tt.splat %f : i32 -> tensor<256x128xi32>
+        %padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32>
+        scf.yield %padd : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      } else {
+        scf.yield %base : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      }
+      %l = tt.load %p0 : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %r = tt.load %p1 : tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      scf.yield %acc : tensor<256x128xf32, #mfma>
+    }
+    tt.return %54 : tensor<256x128xf32, #mfma>
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
index f55ab7855440..0837f16dcf7c 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp
@@ -227,6 +227,7 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
     // Gather use-def chain in block.
     Block *block = op->getBlock();
     bool leadsToLoad = false;
+    bool dontReorder = false;
     SetVector<Operation *> backwardSet;
 
     BackwardSliceOptions options;
@@ -236,6 +237,13 @@
       Block *defBlock = defOp->getBlock();
       if (!block->findAncestorOpInBlock(*defOp))
         return false;
+      // Don't hoist control flow as we don't track backtraces of ops within
+      // their regions.
+      if (isa<scf::IfOp, scf::ForOp, scf::WhileOp>(defOp)) {
+        dontReorder = true;
+        return false;
+      }
+
       // Check for a `load` dependent path.
       leadsToLoad |= isa<triton::LoadOp>(defOp);
       // Only move ops residing in the same block.
@@ -244,6 +252,9 @@
     mlir::getBackwardSlice(op, &backwardSet, options);
     backwardSet.insert(op);
 
+    // If we found ops in the slice we don't want to hoist.
+    if (dontReorder)
+      continue;
     // Don't move a local_store if its source is a load from
     // the same iteration.
     if (isa<ttg::LocalStoreOp>(op) && leadsToLoad)
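For context, the guard added above follows MLIR's usual backward-slice filter pattern: the filter callback both prunes the traversal and records whether a region-holding scf op feeds the candidate, and the caller then skips the move. The sketch below shows that pattern in isolation; the `safeToHoist` helper name, the exact set of scf ops checked, and the surrounding setup are illustrative assumptions, not part of the patch.

```cpp
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Illustrative helper (not from the patch): returns true when `op` has no scf
// control-flow op in its backward slice, i.e. when moving it earlier in its
// block does not risk skipping dependencies hidden inside a nested region.
static bool safeToHoist(Operation *op) {
  Block *block = op->getBlock();
  bool dependsOnControlFlow = false;

  BackwardSliceOptions options;
  options.omitBlockArguments = true;
  options.inclusive = false;
  options.filter = [&](Operation *defOp) -> bool {
    // Ignore defs that are not nested under the enclosing block.
    if (!block->findAncestorOpInBlock(*defOp))
      return false;
    // Ops with regions (scf.if / scf.for / scf.while) may carry extra
    // dependencies inside their bodies that the slice does not visit.
    if (isa<scf::IfOp, scf::ForOp, scf::WhileOp>(defOp)) {
      dependsOnControlFlow = true;
      return false;
    }
    // Only keep walking defs that live directly in the same block.
    return defOp->getBlock() == block;
  };

  llvm::SetVector<Operation *> slice;
  (void)getBackwardSlice(op, &slice, options);
  return !dependsOnControlFlow;
}
```

Keeping the flag outside the lambda, as the patch does, is a deliberately conservative choice: the slice walk still terminates early at the control-flow op, and the caller simply refuses to reorder rather than trying to reason about values defined inside the region.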