Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ struct PipeliningOption {
/// lambda to generate the predicated version of operations.
bool peelEpilogue = true;

bool guardEpilogue = true;

/// Control whether the transformation checks that the number of iterations is
/// greater or equal to the number of stages and skip the transformation if
/// this is not the case. If the loop is dynamic and this is set to true the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ struct LoopPipelinerInternal {
Value lb;
Value step;
bool dynamicLoop;
triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
bool peelEpilogue;
bool guardEpilogue;
triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;

// When peeling the kernel we generate several version of each value for
Expand Down Expand Up @@ -156,6 +157,7 @@ bool LoopPipelinerInternal::initializeLoopInfo(
}
peelEpilogue = options.peelEpilogue;
predicateFn = options.predicateFn;
guardEpilogue = options.guardEpilogue;
if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
LDBG("--no epilogue or predicate set -> BAIL");
return false;
Expand Down
5 changes: 5 additions & 0 deletions python/src/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \
ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); })

// Binds a pass-builder taking five arguments as a Python function `name(pm,
// v0..v4)`; mirrors ADD_PASS_WRAPPER_4 above with one extra parameter.
#define ADD_PASS_WRAPPER_5(name, builder, ty0, ty1, ty2, ty3, ty4)             \
  m.def(name,                                                                  \
        [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3,      \
           ty4 val4) { pm.addPass(builder(val0, val1, val2, val3, val4)); })

#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \
m.def(name, \
[](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); })
Expand Down
4 changes: 3 additions & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,14 @@ def make_ttgir(mod, metadata, options):
global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1
must_guard_epilogue = int(os.getenv("TRITON_HIP_GUARD_EPILOGUE", "1"))

# The `local-prefetch` scheduling variant requires turning on buffer ops.
if options.schedule_hint == "local-prefetch":
global_prefetch = local_prefetch = 1

amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
must_guard_epilogue)
if use_async_copy:
amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
passes.common.add_canonicalizer(pm)
Expand Down
7 changes: 3 additions & 4 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

namespace mlir {

std::unique_ptr<Pass>
createTritonAMDGPUStreamPipelinePass(int numStages = 2, int globalPrefetch = 0,
int localPrefetch = 0,
bool useAsyncCopy = false);
std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(
int numStages = 2, int globalPrefetch = 0, int localPrefetch = 0,
bool useAsyncCopy = false, bool mustGuardEpilogue = true);

std::unique_ptr<Pass>
createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
Expand Down
3 changes: 3 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod
Option<"useAsyncCopy", "use_async_copy",
"bool", /*default*/"false",
"Use AsyncCopyGlobalToLocal to directly load to shared memory">,
Option<"mustGuardEpilogue", "must_guard_epilogue",
"bool", /*default*/"true",
"Require conditionalized loads in epiloge of pipelined loop">,
];
}

Expand Down
159 changes: 132 additions & 27 deletions third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,17 @@ class StreamPipeliner {

public:
StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch,
int _localPrefetch, bool _useAsyncCopy)
int _localPrefetch, bool _useAsyncCopy,
bool _mustGuardEpilogue)
: forOp(_forOp), numStages(_numStages), numBuffers(1),
useAsyncCopy(_useAsyncCopy), schedule(numStages),
useAsyncCopy(_useAsyncCopy), mustGuardEpilogue(_mustGuardEpilogue),
schedule(numStages),
axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
int lastStage = numStages - 1;
stages[SCHED_GLOBAL_LOAD] = 0;
stages[SCHED_LOCAL_STORE] = _globalPrefetch;
stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch;
stages[SCHED_COMPUTE] = lastStage;

options.supportDynamicLoops = true;
options.peelEpilogue = true;
options.predicateFn = streamPredication;
}

LogicalResult pipelineLoop();
Expand All @@ -151,6 +149,10 @@ class StreamPipeliner {
void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx);
void createStreamOps();

// Unguard epilogue
bool safeDAG(Value v, int index);
void checkResultResilience();

void scheduleOp(Operation *op, SchedType type, int stage = -1) {
if (stage < 0)
stage = stages[type];
Expand All @@ -170,6 +172,8 @@ class StreamPipeliner {
// Directly store to shared memory with AsyncCopy when pipelining tt.loads
bool useAsyncCopy;

bool mustGuardEpilogue;

// Stage for each SchedType Op
int stages[SCHED_SIZE];
// Cluster for each SchedType Op
Expand Down Expand Up @@ -198,13 +202,85 @@ class StreamPipeliner {

// Capture list of new shared memory buffers.
SmallVector<Value> sharedMemAllocs;

// Pipelining options for the PipelineExpander
tt::PipeliningOption options;
};

} // namespace

// Conservatively walks the def-chain of `v` and returns true if every
// contribution to it is "safe" with respect to running the pipelined loop's
// epilogue unguarded: either a tt.LoadOp inside this loop (which the expander
// will predicate anyway), the loop-carried block argument number `index`, or
// an op whose listed operands are themselves safe. `index == -1` means "no
// block argument is acceptable on this path". Any op not explicitly handled
// makes the DAG unsafe. NOTE(review): the per-op operand rules below (AND vs
// OR, zero constants) appear chosen so that out-of-bounds epilogue inputs
// cannot corrupt the yielded value — confirm against the pipeliner design.
bool StreamPipeliner::safeDAG(Value v, int index) {
  if (Operation *defOp = v.getDefiningOp()) {
    if (auto loadOp = dyn_cast<tt::LoadOp>(defOp)) {
      // Loads directly in this loop body will be guarded by the expander's
      // predication, so they are safe; loads elsewhere are not.
      return loadOp->getParentOfType<scf::ForOp>() == forOp;
    } else if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(defOp)) {
      // Layout conversion is a pass-through: safe iff its source is.
      return safeDAG(cvtOp.getSrc(), index);
    } else if (auto dotOp = dyn_cast<tt::DotOpInterface>(defOp)) {
      // At least one of A/B must be safe (a zero/guarded factor nullifies the
      // product), and the accumulator C must be safe for this result slot.
      if (!safeDAG(dotOp.getA(), -1) && !safeDAG(dotOp.getB(), -1))
        return false;
      auto C = dotOp->getOperand(2);
      return safeDAG(C, index);
    } else if (auto splatOp = dyn_cast<tt::SplatOp>(defOp)) {
      // Single input must be safe.
      return safeDAG(splatOp.getOperand(), -1);
    } else if (auto bcastOp = dyn_cast<tt::BroadcastOp>(defOp)) {
      // Single input must be safe.
      return safeDAG(bcastOp.getOperand(), -1);
    } else if (auto expandOp = dyn_cast<tt::ExpandDimsOp>(defOp)) {
      // Single input must be safe.
      return safeDAG(expandOp.getOperand(), -1);
    } else if (auto addOp = dyn_cast<arith::AddFOp>(defOp)) {
      // Both inputs must be safe (either addend could perturb the sum).
      return safeDAG(addOp.getLhs(), -1) && safeDAG(addOp.getRhs(), -1);
    } else if (auto subOp = dyn_cast<arith::SubFOp>(defOp)) {
      // Both inputs must be safe.
      return safeDAG(subOp.getLhs(), -1) && safeDAG(subOp.getRhs(), -1);
    } else if (auto mulOp = dyn_cast<arith::MulFOp>(defOp)) {
      // Either input being safe suffices (a safe factor dominates the
      // product).
      return safeDAG(mulOp.getLhs(), -1) || safeDAG(mulOp.getRhs(), -1);
    } else if (auto truncOp = dyn_cast<arith::TruncFOp>(defOp)) {
      // Single input must be safe.
      return safeDAG(truncOp.getOperand(), -1);
    } else if (auto constOp = dyn_cast<arith::ConstantOp>(defOp)) {
      // Only a floating-point constant zero is safe; any other constant
      // (or non-float attribute) falls through to the final `return false`.
      if (auto attr = dyn_cast<FloatAttr>(constOp.getValue()))
        return attr.getValue().isZero();
    } else if (auto exp2Op = dyn_cast<math::Exp2Op>(defOp)) {
      // Single input must be safe.
      return safeDAG(exp2Op.getOperand(), -1);
    } else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
      // Both select arms must be safe since either may be chosen.
      return safeDAG(selectOp.getTrueValue(), -1) &&
             safeDAG(selectOp.getFalseValue(), -1);
    } else if (auto transposeOp = dyn_cast<tt::TransposeOpInterface>(defOp)) {
      // Transpose is a pass-through permutation; its input must be safe.
      return safeDAG(transposeOp.getSrc(), -1);
    } else {
      // Unknown op: default to unsafe rather than risk a wrong epilogue.
      LDBG("Unknown op for unguard epilogue in stream pipeliner");
      return false;
    }
  } else {
    // Not op-defined, so it is a block argument: safe only if it is exactly
    // the expected loop-carried argument for this result slot.
    auto arg = cast<BlockArgument>(v);
    return arg.getArgNumber() == index;
  }
  // Reached only via the non-zero / non-float ConstantOp fall-through above.
  return false;
}
// Inspects every used loop result and forces `mustGuardEpilogue = true`
// unless safeDAG proves the yielded value tolerates an unguarded epilogue
// iteration. Note this only ever tightens the flag; it never clears it.
// TODO(crobeck): is this valid if we have loop-carried
// results needed for the next epilogue stage?
void StreamPipeliner::checkResultResilience() {
  auto yieldVals = forOp.getYieldedValuesMutable().value();
  for (auto [index, res] : llvm::enumerate(forOp.getResults())) {
    if (!res.use_empty()) {
      // Trace the corresponding yield operand back through its def-chain;
      // safeDAG also accepts the matching loop-carried block argument.
      Value yieldVal = yieldVals[index].get();
      // `index + 1` skips block argument 0, the induction variable, to get
      // the iter_arg position matching result `index`.
      if (!safeDAG(yieldVal, index + 1)) // + induction
        mustGuardEpilogue = true;
    }
  }
}

// Init Schedule Config based on settings and loop characteristics.
// Create clusters in order of ops in loop. This can interleave ops
// from different stages in the same cluster to achieve better backend
Expand All @@ -213,6 +289,10 @@ class StreamPipeliner {
// can cause invalid schedules to be produced.
LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {

// Check to see if we can unconditionalize epilogue
if (!mustGuardEpilogue)
checkResultResilience();

bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0;
stages[SCHED_LOCAL_STORE] += maxIndirectionLevel;

Expand Down Expand Up @@ -787,11 +867,16 @@ void StreamPipeliner::scheduleRemainingToLastStage() {
// Assign the rest of the ops to the last stage.
// Take care of the ordering of the ops - uses cannot be scheduled to the
// cluster before the definition.
int count = 0;
auto cluster = clusters[SCHED_COMPUTE];
DenseMap<Operation *, tt::CoarseSchedule::Cluster> opToCluster;
for (auto &op : forOp.getBody()->without_terminator()) {
if (schedule.count(&op) == 0)
if (schedule.count(&op) == 0) {
opToCluster[&op] = cluster;
}
}
if (count != 0) {
mustGuardEpilogue = true;
}
SmallVector<Operation *> queue;
for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) {
Expand Down Expand Up @@ -921,18 +1006,6 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() {
schedule.dump();
});

// Create the final schedule for the kernel loop. This will dictate the
// stages and order of operations to the pipeline expander.
std::vector<std::pair<Operation *, unsigned>> coarseSchedule =
schedule.createFinalSchedule(forOp);

// Fill out the pipeline options.
options.getScheduleFn =
[coarseSchedule](scf::ForOp,
std::vector<std::pair<Operation *, unsigned>> &s) {
s = std::move(coarseSchedule);
};

OpBuilder builder(forOp);
builder.setInsertionPointAfter(forOp);
// Explicitly deallocate created allocations.
Expand All @@ -946,6 +1019,31 @@ LogicalResult StreamPipeliner::pipelineLoop() {
if (failed(preprocessLoopAndBuildSchedule()))
return failure();
LDBG("Loop before sending to expander:\n" << *forOp);
// Create the final schedule for the kernel loop. This will dictate the
// stages and order of operations to the pipeline expander.
std::vector<std::pair<Operation *, unsigned>> coarseSchedule =
schedule.createFinalSchedule(forOp);

// Pipelining options for the PipelineExpander
tt::PipeliningOption options;
options.supportDynamicLoops = true;
options.peelEpilogue = true;
options.guardEpilogue = mustGuardEpilogue;

// predicateFn sets the control flow for the software pipeliner
// streamPredication sets the control flow to conditionally execute the
// epilogue dot but if we can provably unconditionalize the epilogue then we
// can just use default tt::predicateOp
if (mustGuardEpilogue)
options.predicateFn = streamPredication;
else
options.predicateFn = tt::predicateOp;

options.getScheduleFn =
[&coarseSchedule](scf::ForOp,
std::vector<std::pair<Operation *, unsigned>> &s) {
s = std::move(coarseSchedule);
};

IRRewriter rewriter(forOp->getContext());
rewriter.setInsertionPoint(forOp);
Expand Down Expand Up @@ -1011,13 +1109,16 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) {
struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int32_t _numStages, int32_t _globalPrefetch,
int32_t _localPrefetch, bool _useAsyncCopy) {
int32_t _localPrefetch, bool _useAsyncCopy,
bool _mustGuardEpilogue) {
this->numStages = _numStages;

this->globalPrefetch = _globalPrefetch;
this->localPrefetch = _localPrefetch;

this->useAsyncCopy = _useAsyncCopy;

this->mustGuardEpilogue = _mustGuardEpilogue;
}

void runOnOperation() override {
Expand Down Expand Up @@ -1047,15 +1148,19 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
if (!checkPrecondition(forOp))
continue;
StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages),
globalPrefetch, localPrefetch, useAsyncCopy);
globalPrefetch, localPrefetch, useAsyncCopy,
mustGuardEpilogue);
(void)sp.pipelineLoop();
}
}
};
} // namespace

std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass(
int numStages, int globalPrefetch, int localPrefetch, bool useAsyncCopy) {
std::unique_ptr<Pass>
mlir::createTritonAMDGPUStreamPipelinePass(int numStages, int globalPrefetch,
int localPrefetch, bool useAsyncCopy,
bool mustGuardEpilogue) {
return std::make_unique<PipelinePass>(numStages, globalPrefetch,
localPrefetch, useAsyncCopy);
localPrefetch, useAsyncCopy,
mustGuardEpilogue);
}
4 changes: 2 additions & 2 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
mlir::createTritonAMDGPUFoldTrueCmpIPass);
ADD_PASS_WRAPPER_1("add_block_pingpong",
mlir::createTritonAMDGPUBlockPingpongPass, int32_t);
ADD_PASS_WRAPPER_4("add_stream_pipeline",
ADD_PASS_WRAPPER_5("add_stream_pipeline",
mlir::createTritonAMDGPUStreamPipelinePass, int, int, int,
bool);
bool, bool);
ADD_PASS_WRAPPER_1("add_coalesce_async_copy",
mlir::createTritonAMDGPUCoalesceAsyncCopyPass,
std::string);
Expand Down
Loading