From 17dbc36e6a4f06dbdb525ac05b75881ad71c6a58 Mon Sep 17 00:00:00 2001
From: Corbin Robeck
Date: Wed, 9 Apr 2025 01:57:14 +0000
Subject: [PATCH 1/2] optimize conditionals for compute logic in epilogue that
 is wholly dependent on pipelined loads

---
 .../TritonGPU/Transforms/PipelineExpander.h   |   2 +
 .../Transforms/Pipeliner/PipelineExpander.cpp |  18 ++-
 python/src/passes.h                           |   5 +
 third_party/amd/backend/compiler.py           |   4 +-
 .../include/TritonAMDGPUTransforms/Passes.h   |   7 +-
 .../include/TritonAMDGPUTransforms/Passes.td  |   3 +
 .../TritonAMDGPUTransforms/StreamPipeline.cpp | 120 ++++++++++++++++--
 third_party/amd/python/triton_amd.cc          |   4 +-
 8 files changed, 142 insertions(+), 21 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h b/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
index 943cba117656..3bb92017cc68 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h
@@ -51,6 +51,8 @@ struct PipeliningOption {
   /// lambda to generate the predicated version of operations.
   bool peelEpilogue = true;
+  bool guardEpilogue = true;
+
   /// Control whether the transformation checks that the number of iterations is
   /// greater or equal to the number of stages and skip the transformation if
   /// this is not the case. If the loop is dynamic and this is set to true the
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
index 0f4a835ed050..ae60839ffccd 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
@@ -64,9 +64,10 @@ struct LoopPipelinerInternal {
   Value lb;
   Value step;
   bool dynamicLoop;
-  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
   bool peelEpilogue;
+  bool guardEpilogue;
   triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
+  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
   // When peeling the kernel we generate several versions of each value for
   // different stages of the prologue. This map tracks the mapping between
@@ -156,6 +157,8 @@ bool LoopPipelinerInternal::initializeLoopInfo(
   }
   peelEpilogue = options.peelEpilogue;
   predicateFn = options.predicateFn;
+  guardEpilogue = options.guardEpilogue;
+
   if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
@@ -773,11 +776,16 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
         unsigned nextVersion = currentVersion + 1;
         Value pred = predicates[currentVersion];
         Value prevValue = valueMapping[mapVal][currentVersion];
-        auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
-                                                      prevValue);
-        returnValues[ri] = selOp;
+        Value nextValue = pair.value();
+
+        if (guardEpilogue) {
+          nextValue = rewriter.create<arith::SelectOp>(loc, pred, nextValue,
+                                                       prevValue);
+        }
+
+        returnValues[ri] = nextValue;
         if (nextVersion <= maxStage)
-          setValueMapping(mapVal, selOp, nextVersion);
+          setValueMapping(mapVal, nextValue, nextVersion);
       }
     }
   }
diff --git a/python/src/passes.h b/python/src/passes.h
index 629fe362d8b2..f9c7a8394fdc 100644
--- a/python/src/passes.h
+++ b/python/src/passes.h
@@ -19,6 +19,11 @@
   m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2,        \
                  ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); })

+#define ADD_PASS_WRAPPER_5(name, builder, ty0, ty1, ty2, ty3, ty4)           \
+  m.def(name,                                                                \
+        [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3,    \
+           ty4 val4) { pm.addPass(builder(val0, val1, val2, val3, val4)); })
+
 #define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0)                        \
   m.def(name,                                                                \
         [](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); })
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index b4f18e35b701..2b5b451cb821 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -246,12 +246,14 @@ def make_ttgir(mod, metadata, options):
         global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
         local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
         use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1
+        must_guard_epilogue = 1

         # The `local-prefetch` scheduling variant requires turning on buffer ops.
         if options.schedule_hint == "local-prefetch":
             global_prefetch = local_prefetch = 1

-        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
+        amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
+                                               must_guard_epilogue)
         if use_async_copy:
             amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
         passes.common.add_canonicalizer(pm)
diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h
index e01ff34b5764..1740dd0069dc 100644
--- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h
+++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h
@@ -8,10 +8,9 @@ namespace mlir {

-std::unique_ptr<Pass>
-createTritonAMDGPUStreamPipelinePass(int numStages = 2, int globalPrefetch = 0,
-                                     int localPrefetch = 0,
-                                     bool useAsyncCopy = false);
+std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(
+    int numStages = 2, int globalPrefetch = 0, int localPrefetch = 0,
+    bool useAsyncCopy = false, bool mustGuardEpilogue = true);

 std::unique_ptr<Pass>
 createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
index dee0d4bd7fe5..7f7498c748b7 100644
--- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
+++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -28,6 +28,9 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> {
    Option<"useAsyncCopy", "use_async_copy",
           "bool", /*default*/"false",
           "Use AsyncCopyGlobalToLocal to directly load to shared memory">,
+   Option<"mustGuardEpilogue", "must_guard_epilogue",
+          "bool", /*default*/"true",
+          "Require conditionalized loads in the epilogue of the pipelined loop">,
   ];
 }
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 5b2a7a18a7f5..664282258be0 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -116,9 +116,11 @@ class StreamPipeliner {
 public:
   StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch,
-                  int _localPrefetch, bool _useAsyncCopy)
+                  int _localPrefetch, bool _useAsyncCopy,
+                  bool _mustGuardEpilogue)
       : forOp(_forOp), numStages(_numStages), numBuffers(1),
-        useAsyncCopy(_useAsyncCopy), schedule(numStages),
+        useAsyncCopy(_useAsyncCopy), mustGuardEpilogue(_mustGuardEpilogue),
+        schedule(numStages),
         axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
     int lastStage = numStages - 1;
     stages[SCHED_GLOBAL_LOAD] = 0;
@@ -128,7 +130,11 @@ class StreamPipeliner {

     options.supportDynamicLoops = true;
     options.peelEpilogue = true;
-    options.predicateFn = streamPredication;
+    options.guardEpilogue = _mustGuardEpilogue;
+    if (_mustGuardEpilogue)
+      options.predicateFn = streamPredication;
+    else
+      options.predicateFn = tt::predicateOp;
   }

   LogicalResult pipelineLoop();
@@ -151,6 +157,10 @@ class StreamPipeliner {
   void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx);
   void createStreamOps();

+  // Epilogue guard analysis: decide whether the guard can be dropped
+  bool safeDAG(Value v, int index);
+  void checkResultResilience();
+
   void scheduleOp(Operation *op, SchedType type, int stage = -1) {
     if (stage < 0)
       stage = stages[type];
@@ -170,6 +180,8 @@ class StreamPipeliner {

   // Directly store to shared memory with AsyncCopy when pipelining tt.loads
   bool useAsyncCopy;
+  bool mustGuardEpilogue;
+
   // Stage for each SchedType Op
   int stages[SCHED_SIZE];
   // Cluster for each SchedType Op
@@ -205,6 +217,80 @@ class StreamPipeliner {

 } // namespace

+bool StreamPipeliner::safeDAG(Value v, int index) {
+  if (Operation *defOp = v.getDefiningOp()) {
+    if (auto loadOp = dyn_cast<tt::LoadOp>(defOp)) {
+      // Loads in the loop will be guarded
+      return loadOp->getParentOfType<scf::ForOp>() == forOp;
+    } else if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(defOp)) {
+      return safeDAG(cvtOp.getSrc(), index);
+    } else if (auto dotOp = dyn_cast<tt::DotOp>(defOp)) {
+      // at least one of the A/B inputs must be safe
+      if (!safeDAG(dotOp.getA(), -1) && !safeDAG(dotOp.getB(), -1))
+        return false;
+      auto C = dotOp->getOperand(2);
+      return safeDAG(C, index);
+    } else if (auto splatOp = dyn_cast<tt::SplatOp>(defOp)) {
+      // input must be safe
+      return safeDAG(splatOp.getOperand(), -1);
+    } else if (auto bcastOp = dyn_cast<tt::BroadcastOp>(defOp)) {
+      // input must be safe
+      return safeDAG(bcastOp.getOperand(), -1);
+    } else if (auto expandOp = dyn_cast<tt::ExpandDimsOp>(defOp)) {
+      // input must be safe
+      return safeDAG(expandOp.getOperand(), -1);
+    } else if (auto addOp = dyn_cast<arith::AddFOp>(defOp)) {
+      // both inputs must be safe
+      return safeDAG(addOp.getLhs(), -1) && safeDAG(addOp.getRhs(), -1);
+    } else if (auto subOp = dyn_cast<arith::SubFOp>(defOp)) {
+      // both inputs must be safe
+      return safeDAG(subOp.getLhs(), -1) && safeDAG(subOp.getRhs(), -1);
+    } else if (auto mulOp = dyn_cast<arith::MulFOp>(defOp)) {
+      // either input must be safe
+      return safeDAG(mulOp.getLhs(), -1) || safeDAG(mulOp.getRhs(), -1);
+    } else if (auto truncOp = dyn_cast<arith::TruncFOp>(defOp)) {
+      // input must be safe
+      return safeDAG(truncOp.getOperand(), -1);
+    } else if (auto constOp = dyn_cast<arith::ConstantOp>(defOp)) {
+      // check for constant zero
+      if (auto attr = dyn_cast<FloatAttr>(constOp.getValue()))
+        return attr.getValue().isZero();
+    } else if (auto exp2Op = dyn_cast<math::Exp2Op>(defOp)) {
+      // input must be safe
+      return safeDAG(exp2Op.getOperand(), -1);
+    } else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
+      // both inputs must be safe
+      return safeDAG(selectOp.getTrueValue(), -1) &&
+             safeDAG(selectOp.getFalseValue(), -1);
+    } else if (auto transposeOp = dyn_cast<tt::TransOp>(defOp)) {
+      // input must be safe
+      return safeDAG(transposeOp.getSrc(), -1);
+    } else {
+      // Unknown op: default to false
+      LDBG("Unknown op for unguard epilogue in stream pipeliner");
+      return false;
+    }
+  } else {
+    // check block arg
+    auto arg = cast<BlockArgument>(v);
+    return arg.getArgNumber() == index;
+  }
+  return false;
+}
+
+void StreamPipeliner::checkResultResilience() {
+  auto yieldVals = forOp.getYieldedValuesMutable().value();
+  for (auto [index, res] : llvm::enumerate(forOp.getResults())) {
+    if (!res.use_empty()) {
+      // Backtrack the yielded value and check that its DAG is safe
+      // (e.g. an accumulator initialized to zero)
+      Value yieldVal = yieldVals[index].get();
+      if (!safeDAG(yieldVal, index + 1)) // +1 skips the induction variable
+        mustGuardEpilogue = true;
+    }
+  }
+}
+
 // Init Schedule Config based on settings and loop characteristics.
 // Create clusters in order of ops in loop. This can interleave ops
 // from different stages in the same cluster to achieve better backend
@@ -213,6 +299,10 @@
 // can cause invalid schedules to be produced.
 LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {
+  // Check whether we can leave the epilogue unguarded
+  if (!mustGuardEpilogue)
+    checkResultResilience();
+
   bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0;
   stages[SCHED_LOCAL_STORE] += maxIndirectionLevel;
@@ -787,11 +877,17 @@ void StreamPipeliner::scheduleRemainingToLastStage() {
   // Assign the rest of the ops to the last stage.
   // Take care of the ordering of the ops - uses cannot be scheduled to the
   // cluster before the definition.
+  int count = 0;
   auto cluster = clusters[SCHED_COMPUTE];
   DenseMap<Operation *, tt::CoarseSchedule::Cluster> opToCluster;
   for (auto &op : forOp.getBody()->without_terminator()) {
-    if (schedule.count(&op) == 0)
+    if (schedule.count(&op) == 0) {
       opToCluster[&op] = cluster;
+      count++;
+    }
+  }
+  if (count != 0) {
+    mustGuardEpilogue = true;
   }
   SmallVector<Operation *> queue;
   for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) {
@@ -1011,13 +1106,16 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) {
 struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
   PipelinePass() = default;
   PipelinePass(int32_t _numStages, int32_t _globalPrefetch,
-               int32_t _localPrefetch, bool _useAsyncCopy) {
+               int32_t _localPrefetch, bool _useAsyncCopy,
+               bool _mustGuardEpilogue) {
     this->numStages = _numStages;
     this->globalPrefetch = _globalPrefetch;
     this->localPrefetch = _localPrefetch;
     this->useAsyncCopy = _useAsyncCopy;
+
+    this->mustGuardEpilogue = _mustGuardEpilogue;
   }

   void runOnOperation() override {
@@ -1047,15 +1145,19 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
       if (!checkPrecondition(forOp))
         continue;
       StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages),
-                         globalPrefetch, localPrefetch, useAsyncCopy);
+                         globalPrefetch, localPrefetch, useAsyncCopy,
+                         mustGuardEpilogue);
       (void)sp.pipelineLoop();
     }
   }
 };

 } // namespace

-std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass(
-    int numStages, int globalPrefetch, int localPrefetch, bool useAsyncCopy) {
+std::unique_ptr<Pass>
+mlir::createTritonAMDGPUStreamPipelinePass(int numStages, int globalPrefetch,
+                                           int localPrefetch, bool useAsyncCopy,
+                                           bool mustGuardEpilogue) {
   return std::make_unique<PipelinePass>(numStages, globalPrefetch,
-                                        localPrefetch, useAsyncCopy);
+                                        localPrefetch, useAsyncCopy,
+                                        mustGuardEpilogue);
 }
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index 00397941bc7c..e20def349450 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -78,9 +78,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
                     mlir::createTritonAMDGPUFoldTrueCmpIPass);
   ADD_PASS_WRAPPER_1("add_block_pingpong",
                      mlir::createTritonAMDGPUBlockPingpongPass, int32_t);
-  ADD_PASS_WRAPPER_4("add_stream_pipeline",
+  ADD_PASS_WRAPPER_5("add_stream_pipeline",
                      mlir::createTritonAMDGPUStreamPipelinePass, int, int, int,
-                     bool);
+                     bool, bool);
   ADD_PASS_WRAPPER_1("add_coalesce_async_copy",
                      mlir::createTritonAMDGPUCoalesceAsyncCopyPass,
                      std::string);

From 22ad1d340c03057c2dab0eb3c5331c1d928811c2 Mon Sep 17 00:00:00 2001
From: Corbin Robeck
Date: Wed, 9 Apr 2025 13:19:35 +0000
Subject: [PATCH 2/2] update and clean up

---
 .../Transforms/Pipeliner/PipelineExpander.cpp | 16 ++----
 third_party/amd/backend/compiler.py           |  2 +-
 .../TritonAMDGPUTransforms/StreamPipeline.cpp | 51 ++++++++++---------
 3 files changed, 33 insertions(+), 36 deletions(-)

diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
index ae60839ffccd..f99373607785 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
@@ -66,8 +66,8 @@ struct LoopPipelinerInternal {
   bool dynamicLoop;
   bool peelEpilogue;
   bool guardEpilogue;
-  triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
   triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
+  triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
   // When peeling the kernel we generate several versions of each value for
   // different stages of the prologue. This map tracks the mapping between
@@ -158,7 +158,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
   peelEpilogue = options.peelEpilogue;
   predicateFn = options.predicateFn;
   guardEpilogue = options.guardEpilogue;
-
   if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
@@ -776,16 +775,11 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
         unsigned nextVersion = currentVersion + 1;
         Value pred = predicates[currentVersion];
         Value prevValue = valueMapping[mapVal][currentVersion];
-        Value nextValue = pair.value();
-
-        if (guardEpilogue) {
-          nextValue = rewriter.create<arith::SelectOp>(loc, pred, nextValue,
-                                                       prevValue);
-        }
-
-        returnValues[ri] = nextValue;
+        auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+                                                      prevValue);
+        returnValues[ri] = selOp;
         if (nextVersion <= maxStage)
-          setValueMapping(mapVal, nextValue, nextVersion);
+          setValueMapping(mapVal, selOp, nextVersion);
       }
     }
   }
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index 2b5b451cb821..72ca0ab39c63 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -246,7 +246,7 @@ def make_ttgir(mod, metadata, options):
         global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
         local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
         use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1
-        must_guard_epilogue = 1
+        must_guard_epilogue = int(os.getenv("TRITON_HIP_GUARD_EPILOGUE", "1"))

         # The `local-prefetch` scheduling variant requires turning on buffer ops.
         if options.schedule_hint == "local-prefetch":
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 664282258be0..0d561f6a5248 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -127,14 +127,6 @@ class StreamPipeliner {
     stages[SCHED_LOCAL_STORE] = _globalPrefetch;
     stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch;
     stages[SCHED_COMPUTE] = lastStage;
-
-    options.supportDynamicLoops = true;
-    options.peelEpilogue = true;
-    options.guardEpilogue = _mustGuardEpilogue;
-    if (_mustGuardEpilogue)
-      options.predicateFn = streamPredication;
-    else
-      options.predicateFn = tt::predicateOp;
   }

   LogicalResult pipelineLoop();
@@ -210,9 +202,6 @@ class StreamPipeliner {

   // Capture list of new shared memory buffers.
   SmallVector<Value> sharedMemAllocs;
-
-  // Pipelining options for the PipelineExpander
-  tt::PipeliningOption options;
 };

 } // namespace
@@ -277,7 +266,8 @@ bool StreamPipeliner::safeDAG(Value v, int index) {
   }
   return false;
 }
-
+// TODO(crobeck): is this valid if we have loop-carried
+// results needed for the next epilogue stage?
 void StreamPipeliner::checkResultResilience() {
   auto yieldVals = forOp.getYieldedValuesMutable().value();
   for (auto [index, res] : llvm::enumerate(forOp.getResults())) {
@@ -1016,18 +1006,6 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() {
     schedule.dump();
   });

-  // Create the final schedule for the kernel loop. This will dictate the
-  // stages and order of operations to the pipeline expander.
-  std::vector<std::pair<Operation *, unsigned>> coarseSchedule =
-      schedule.createFinalSchedule(forOp);
-
-  // Fill out the pipeline options.
-  options.getScheduleFn =
-      [coarseSchedule](scf::ForOp,
-                       std::vector<std::pair<Operation *, unsigned>> &s) {
-        s = std::move(coarseSchedule);
-      };
-
   OpBuilder builder(forOp);
   builder.setInsertionPointAfter(forOp);
   // Explicitly deallocate created allocations.
@@ -1041,6 +1019,31 @@ LogicalResult StreamPipeliner::pipelineLoop() {
   if (failed(preprocessLoopAndBuildSchedule()))
     return failure();
   LDBG("Loop before sending to expander:\n" << *forOp);
+  // Create the final schedule for the kernel loop. This will dictate the
+  // stages and order of operations to the pipeline expander.
+  std::vector<std::pair<Operation *, unsigned>> coarseSchedule =
+      schedule.createFinalSchedule(forOp);
+
+  // Pipelining options for the PipelineExpander
+  tt::PipeliningOption options;
+  options.supportDynamicLoops = true;
+  options.peelEpilogue = true;
+  options.guardEpilogue = mustGuardEpilogue;
+
+  // predicateFn sets the control flow for the software pipeliner:
+  // streamPredication conditionally executes the epilogue dot. If we can
+  // prove the epilogue needs no guard, we can instead use the default
+  // tt::predicateOp.
+  if (mustGuardEpilogue)
+    options.predicateFn = streamPredication;
+  else
+    options.predicateFn = tt::predicateOp;
+
+  options.getScheduleFn =
+      [&coarseSchedule](scf::ForOp,
+                        std::vector<std::pair<Operation *, unsigned>> &s) {
+        s = std::move(coarseSchedule);
+      };

   IRRewriter rewriter(forOp->getContext());
   rewriter.setInsertionPoint(forOp);
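
---

Usage note (editor's sketch, not part of the patch series): the guard toggle above is driven by the TRITON_HIP_GUARD_EPILOGUE environment variable read in third_party/amd/backend/compiler.py, and the pass only drops the guard when checkResultResilience()/safeDAG() can prove every live yielded value is resilient to unguarded loads, e.g. an accumulator that starts at zero and is updated only through dot/add chains over the pipelined loads. A minimal way one might exercise both paths is sketched below; the kernel, file name, and sizes are hypothetical, only the environment variable comes from the patch:

    # guard_epilogue_demo.py -- hypothetical smoke test
    import os
    # "1" (the default) keeps the select-guarded epilogue; "0" lets the pass
    # try to drop the guard when the resilience analysis proves it safe.
    os.environ["TRITON_HIP_GUARD_EPILOGUE"] = "0"

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def dot_kernel(a_ptr, b_ptr, c_ptr, K, BLOCK_M: tl.constexpr,
                   BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        # Zero-initialized accumulator fed only by dot() over pipelined
        # loads: the pattern safeDAG() is written to recognize.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        for k in range(0, K, BLOCK_K):  # K loop targeted by the pipeliner
            a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
            b = tl.load(b_ptr + (k + offs_k)[:, None] * BLOCK_N + offs_n[None, :])
            acc += tl.dot(a, b)
        tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)

    # b's row stride equals BLOCK_N only because N == BLOCK_N == 32 here.
    a = torch.randn(32, 128, device="cuda", dtype=torch.float16)
    b = torch.randn(128, 32, device="cuda", dtype=torch.float16)
    c = torch.empty(32, 32, device="cuda", dtype=torch.float32)
    dot_kernel[(1,)](a, b, c, 128, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32)
    torch.testing.assert_close(c, a.float() @ b.float(), rtol=1e-2, atol=1e-2)

Running once with the variable at "1" and once at "0" should give identical results; only the generated epilogue control flow differs.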