Merged
37 changes: 37 additions & 0 deletions test/TritonGPU/amd/amd-reorder-instructions.mlir
@@ -499,3 +499,40 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
tt.return
}
}


// -----

#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: dont_hoist_scf_ops
// Make sure we don't hoist scf ops above their dependencies.
tt.func public @dont_hoist_scf_ops(%init: tensor<256x128xf32, #mfma>,
%base: tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>,
%p1: tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>, %i1: i1) -> (tensor<256x128xf32, #mfma>) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c4_i32 = arith.constant 4 : i32
%cst = arith.constant 1.44269502 : f32
%c128_i32 = arith.constant 128 : i32
// CHECK: scf.for
%54 = scf.for %arg21 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg = %init) -> (tensor<256x128xf32, #mfma>) : i32 {
// CHECK: arith.addi
%f = arith.addi %arg21, %c128_i32 : i32
// CHECK: scf.if
// CHECK: tt.load
%p0 = scf.if %i1 -> tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> {
%t = tt.splat %f : i32 -> tensor<256x128xi32>
%padd = tt.addptr %base, %t : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>, tensor<256x128xi32>
scf.yield %padd : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
} else {
scf.yield %base : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
}
%l = tt.load %p0 : tensor<256x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
%r = tt.load %p1 : tensor<128x128x!tt.ptr<f16>, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
%acc = tt.dot %l, %r, %arg : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
scf.yield %acc : tensor<256x128xf32, #mfma>
}
tt.return %54 : tensor<256x128xf32, #mfma>
}
}
@@ -227,6 +227,7 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
// Gather use-def chain in block.
Block *block = op->getBlock();
bool leadsToLoad = false;
bool dontReorder = false;
SetVector<Operation *> backwardSet;

BackwardSliceOptions options;
@@ -236,6 +237,13 @@
Block *defBlock = defOp->getBlock();
if (!block->findAncestorOpInBlock(*defOp))
return false;
// Don't hoist control flow as we don't track backtraces of ops within
// their regions.
if (isa<scf::IfOp, scf::ForOp, scf::WhileOp>(defOp)) {
dontReorder = true;
return false;
}

Review thread on this check:

Contributor: This could be defOp->getNumRegions() != 0. I think this could happen for tt.reduce?

Collaborator (author): But reduces are fine to hoist, right?

Contributor: Well, it looks like tl.reduce can take a function with arbitrary code. That could include external constants, which could then end up out of order.

Contributor: Actually, I can't see a way to get that to fail. Just being overly cautious.

Collaborator (author): I thought it had to be isolated from above, but I guess I don't see anything blocking that.
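For reference, a minimal sketch of the broader guard floated in the thread, using only the standard MLIR Operation::getNumRegions() API; this variant is hypothetical and is not what the PR merged:

// Hypothetical alternative from the review discussion (not merged):
// treat any region-holding op (scf.if/for/while, but also tt.reduce)
// as a reorder barrier instead of listing specific scf ops.
if (defOp->getNumRegions() != 0) {
  dontReorder = true;
  return false;
}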

// Check for a `load` dependent path.
leadsToLoad |= isa<triton::LoadOp>(defOp);
// Only move ops residing in the same block.
@@ -244,6 +252,9 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
mlir::getBackwardSlice(op, &backwardSet, options);
backwardSet.insert(op);

// Skip this op if its slice contains ops we don't want to reorder.
if (dontReorder)
continue;
// Don't move a local_store if its source is a load from
// the same iteration.
if (isa<ttg::LocalStoreOp>(op) && leadsToLoad)
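To see how the pieces fit together end to end, here is a condensed sketch of the scheduling filter as it reads after this change. It assumes the visible fragments form the body of an MLIR BackwardSliceOptions::filter callback (from mlir/Analysis/SliceAnalysis.h) and abbreviates the hunk context elided above; it is a reconstruction, not the verbatim merged code:

// Reconstruction under the assumptions above: the filter both restricts
// the backward slice to the current block and records, via dontReorder,
// whether the slice touched a region-holding scf op.
bool leadsToLoad = false;
bool dontReorder = false;
SetVector<Operation *> backwardSet;

BackwardSliceOptions options;
options.filter = [&](Operation *defOp) {
  Block *defBlock = defOp->getBlock();
  // Ignore defs outside the block being scheduled.
  if (!block->findAncestorOpInBlock(*defOp))
    return false;
  // Don't hoist control flow; backward slices of ops inside its regions
  // are not tracked, so moving it could break hidden dependencies.
  if (isa<scf::IfOp, scf::ForOp, scf::WhileOp>(defOp)) {
    dontReorder = true;
    return false;
  }
  // Check for a load-dependent path.
  leadsToLoad |= isa<triton::LoadOp>(defOp);
  // Only move ops residing in the same block.
  return defBlock == block;
};
mlir::getBackwardSlice(op, &backwardSet, options);
backwardSet.insert(op);

// If the slice contains ops we must not move, leave this op in place.
if (dontReorder)
  continue;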