Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ struct PipeliningOption {
/// lambda to generate the predicated version of operations.
bool peelEpilogue = true;

bool guardEpilogue = true;

/// Control whether the transformation checks that the number of iterations is
/// greater or equal to the number of stages and skip the transformation if
/// this is not the case. If the loop is dynamic and this is set to true the
Expand Down
18 changes: 13 additions & 5 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ struct LoopPipelinerInternal {
Value lb;
Value step;
bool dynamicLoop;
triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
bool peelEpilogue;
bool guardEpilogue;
triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;
triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;

// When peeling the kernel we generate several version of each value for
// different stage of the prologue. This map tracks the mapping between
Expand Down Expand Up @@ -156,6 +157,8 @@ bool LoopPipelinerInternal::initializeLoopInfo(
}
peelEpilogue = options.peelEpilogue;
predicateFn = options.predicateFn;
guardEpilogue = options.guardEpilogue;

if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
LDBG("--no epilogue or predicate set -> BAIL");
return false;
Expand Down Expand Up @@ -773,11 +776,16 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
unsigned nextVersion = currentVersion + 1;
Value pred = predicates[currentVersion];
Value prevValue = valueMapping[mapVal][currentVersion];
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
prevValue);
returnValues[ri] = selOp;
Value nextValue = pair.value();

if (guardEpilogue) {
nextValue = rewriter.create<arith::SelectOp>(loc, pred, nextValue,
prevValue);
}

returnValues[ri] = nextValue;
if (nextVersion <= maxStage)
setValueMapping(mapVal, selOp, nextVersion);
setValueMapping(mapVal, nextValue, nextVersion);
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions python/src/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
m.def(name, [](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, \
ty3 val3) { pm.addPass(builder(val0, val1, val2, val3)); })

// Bind a pass-constructor `builder` taking five arguments as a Python function
// `name(pm, v0..v4)` that appends the built pass to the given PassManager.
// Follows the same shape as ADD_PASS_WRAPPER_4 above, extended by one argument.
#define ADD_PASS_WRAPPER_5(name, builder, ty0, ty1, ty2, ty3, ty4) \
m.def(name, \
[](mlir::PassManager &pm, ty0 val0, ty1 val1, ty2 val2, ty3 val3, \
ty4 val4) { pm.addPass(builder(val0, val1, val2, val3, val4)); })

#define ADD_PASS_OPTION_WRAPPER_1(name, builder, ty0) \
m.def(name, \
[](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder({val0})); })
Expand Down
4 changes: 3 additions & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,14 @@ def make_ttgir(mod, metadata, options):
global_prefetch = int(os.getenv("TRITON_HIP_GLOBAL_PREFETCH", "0"))
local_prefetch = int(os.getenv("TRITON_HIP_LOCAL_PREFETCH", "0"))
use_async_copy = int(os.getenv("TRITON_HIP_USE_ASYNC_COPY", "0")) == 1
must_guard_epilogue = 1

# The `local-prefetch` scheduling variant requires turning on buffer ops.
if options.schedule_hint == "local-prefetch":
global_prefetch = local_prefetch = 1

amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy,
must_guard_epilogue)
if use_async_copy:
amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
passes.common.add_canonicalizer(pm)
Expand Down
7 changes: 3 additions & 4 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

namespace mlir {

std::unique_ptr<Pass>
createTritonAMDGPUStreamPipelinePass(int numStages = 2, int globalPrefetch = 0,
int localPrefetch = 0,
bool useAsyncCopy = false);
std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(
int numStages = 2, int globalPrefetch = 0, int localPrefetch = 0,
bool useAsyncCopy = false, bool mustGuardEpilogue = true);

std::unique_ptr<Pass>
createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
Expand Down
3 changes: 3 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::Mod
Option<"useAsyncCopy", "use_async_copy",
"bool", /*default*/"false",
"Use AsyncCopyGlobalToLocal to directly load to shared memory">,
Option<"mustGuardEpilogue", "must_guard_epilogue",
"bool", /*default*/"true",
"Require conditionalized loads in epiloge of pipelined loop">,
];
}

Expand Down
120 changes: 111 additions & 9 deletions third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@ class StreamPipeliner {

public:
StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch,
int _localPrefetch, bool _useAsyncCopy)
int _localPrefetch, bool _useAsyncCopy,
bool _mustGuardEpilogue)
: forOp(_forOp), numStages(_numStages), numBuffers(1),
useAsyncCopy(_useAsyncCopy), schedule(numStages),
useAsyncCopy(_useAsyncCopy), mustGuardEpilogue(_mustGuardEpilogue),
schedule(numStages),
axisInfoAnalysis(forOp->getParentOfType<ModuleOp>()) {
int lastStage = numStages - 1;
stages[SCHED_GLOBAL_LOAD] = 0;
Expand All @@ -128,7 +130,11 @@ class StreamPipeliner {

options.supportDynamicLoops = true;
options.peelEpilogue = true;
options.predicateFn = streamPredication;
options.guardEpilogue = _mustGuardEpilogue;
if (_mustGuardEpilogue)
options.predicateFn = streamPredication;
else
options.predicateFn = tt::predicateOp;
}

LogicalResult pipelineLoop();
Expand All @@ -151,6 +157,10 @@ class StreamPipeliner {
void createStreamCopy(tt::LoadOp loadOp, Value alloc, Value extractIdx);
void createStreamOps();

// Epilogue-guard analysis: decide whether the peeled epilogue can be left unguarded.
bool safeDAG(Value v, int index);
void checkResultResilience();

void scheduleOp(Operation *op, SchedType type, int stage = -1) {
if (stage < 0)
stage = stages[type];
Expand All @@ -170,6 +180,8 @@ class StreamPipeliner {
// Directly store to shared memory with AsyncCopy when pipelining tt.loads
bool useAsyncCopy;

bool mustGuardEpilogue;

// Stage for each SchedType Op
int stages[SCHED_SIZE];
// Cluster for each SchedType Op
Expand Down Expand Up @@ -205,6 +217,80 @@ class StreamPipeliner {

} // namespace

// Conservatively decides whether value `v` is "safe" to leave unguarded in the
// peeled epilogue. A value is safe if every path of its defining DAG either
// (a) ends at a load inside this loop (which will be predicated anyway),
// (b) ends at a zero constant, or (c) reaches the loop-carried block argument
// with number `index` (callers pass -1 to forbid matching any block argument,
// since block-arg numbers are unsigned and can never equal -1).
// Any op not explicitly handled below makes the DAG unsafe.
bool StreamPipeliner::safeDAG(Value v, int index) {
if (Operation *defOp = v.getDefiningOp()) {
if (auto loadOp = dyn_cast<tt::LoadOp>(defOp)) {
// Loads in the loop will be guarded
return loadOp->getParentOfType<scf::ForOp>() == forOp;
} else if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(defOp)) {
// Layout conversion is a no-op for safety; look through it.
return safeDAG(cvtOp.getSrc(), index);
} else if (auto dotOp = dyn_cast<tt::DotOpInterface>(defOp)) {
// At least one of A/B must be safe (a zero operand nullifies the
// product); the accumulator C must itself be safe and may carry the
// tracked block argument.
if (!safeDAG(dotOp.getA(), -1) && !safeDAG(dotOp.getB(), -1))
return false;
auto C = dotOp->getOperand(2);
return safeDAG(C, index);
} else if (auto splatOp = dyn_cast<tt::SplatOp>(defOp)) {
// the single input must be safe
return safeDAG(splatOp.getOperand(), -1);
} else if (auto bcastOp = dyn_cast<tt::BroadcastOp>(defOp)) {
// the single input must be safe
return safeDAG(bcastOp.getOperand(), -1);
} else if (auto expandOp = dyn_cast<tt::ExpandDimsOp>(defOp)) {
// the single input must be safe
return safeDAG(expandOp.getOperand(), -1);
} else if (auto addOp = dyn_cast<arith::AddFOp>(defOp)) {
// both inputs must be safe
return safeDAG(addOp.getLhs(), -1) && safeDAG(addOp.getRhs(), -1);
} else if (auto subOp = dyn_cast<arith::SubFOp>(defOp)) {
// both inputs must be safe
return safeDAG(subOp.getLhs(), -1) && safeDAG(subOp.getRhs(), -1);
} else if (auto mulOp = dyn_cast<arith::MulFOp>(defOp)) {
// either input must be safe (a zero factor makes the product zero)
return safeDAG(mulOp.getLhs(), -1) || safeDAG(mulOp.getRhs(), -1);
} else if (auto truncOp = dyn_cast<arith::TruncFOp>(defOp)) {
// the single input must be safe
return safeDAG(truncOp.getOperand(), -1);
} else if (auto constOp = dyn_cast<arith::ConstantOp>(defOp)) {
// only a floating-point constant zero is safe; any other constant
// (including non-float attrs) falls through to the final `return false`
if (auto attr = dyn_cast<FloatAttr>(constOp.getValue()))
return attr.getValue().isZero();
} else if (auto exp2Op = dyn_cast<math::Exp2Op>(defOp)) {
// the single input must be safe
return safeDAG(exp2Op.getOperand(), -1);
} else if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
// both selectable values must be safe (the condition does not matter)
return safeDAG(selectOp.getTrueValue(), -1) &&
safeDAG(selectOp.getFalseValue(), -1);
} else if (auto transposeOp = dyn_cast<tt::TransposeOpInterface>(defOp)) {
// input must be safe
return safeDAG(transposeOp.getSrc(), -1);
} else {
// Unknown op: conservatively unsafe.
LDBG("Unknown op for unguard epilogue in stream pipeliner");
return false;
}
} else {
// Not defined by an op: must be the block argument we are tracking.
auto arg = cast<BlockArgument>(v);
return arg.getArgNumber() == index;
}
// Reached only via the constant case with a non-zero/non-float attribute.
return false;
}

// Re-enables epilogue guarding if any externally-used loop result cannot be
// proven safe: for each used result, backtrack its yielded value through
// safeDAG, tracking the matching iter-arg (block argument index + 1, because
// block argument 0 is the induction variable). If any such DAG is unsafe,
// set mustGuardEpilogue so the pipeliner keeps the predicated epilogue.
// NOTE(review): the old "Check init value == 0" comment is not implemented
// here — safeDAG only accepts zero constants along the DAG; confirm the loop
// init values are covered elsewhere.
void StreamPipeliner::checkResultResilience() {
auto yieldVals = forOp.getYieldedValuesMutable().value();
for (auto [index, res] : llvm::enumerate(forOp.getResults())) {
if (!res.use_empty()) {
// Backtrack the yielded value feeding this result.
Value yieldVal = yieldVals[index].get();
if (!safeDAG(yieldVal, index + 1)) // + induction
mustGuardEpilogue = true;
}
}
}

// Init Schedule Config based on settings and loop characteristics.
// Create clusters in order of ops in loop. This can interleave ops
// from different stages in the same cluster to achieve better backend
Expand All @@ -213,6 +299,10 @@ class StreamPipeliner {
// can cause invalid schedules to be produced.
LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) {

// Check to see if we can unconditionalize epilogue
if (!mustGuardEpilogue)
checkResultResilience();

bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0;
stages[SCHED_LOCAL_STORE] += maxIndirectionLevel;

Expand Down Expand Up @@ -787,11 +877,16 @@ void StreamPipeliner::scheduleRemainingToLastStage() {
// Assign the rest of the ops to the last stage.
// Take care of the ordering of the ops - uses cannot be scheduled to the
// cluster before the definition.
int count = 0;
auto cluster = clusters[SCHED_COMPUTE];
DenseMap<Operation *, tt::CoarseSchedule::Cluster> opToCluster;
for (auto &op : forOp.getBody()->without_terminator()) {
if (schedule.count(&op) == 0)
if (schedule.count(&op) == 0) {
opToCluster[&op] = cluster;
}
}
if (count != 0) {
mustGuardEpilogue = true;
}
SmallVector<Operation *> queue;
for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) {
Expand Down Expand Up @@ -1011,13 +1106,16 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) {
struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
PipelinePass() = default;
PipelinePass(int32_t _numStages, int32_t _globalPrefetch,
int32_t _localPrefetch, bool _useAsyncCopy) {
int32_t _localPrefetch, bool _useAsyncCopy,
bool _mustGuardEpilogue) {
this->numStages = _numStages;

this->globalPrefetch = _globalPrefetch;
this->localPrefetch = _localPrefetch;

this->useAsyncCopy = _useAsyncCopy;

this->mustGuardEpilogue = _mustGuardEpilogue;
}

void runOnOperation() override {
Expand Down Expand Up @@ -1047,15 +1145,19 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
if (!checkPrecondition(forOp))
continue;
StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages),
globalPrefetch, localPrefetch, useAsyncCopy);
globalPrefetch, localPrefetch, useAsyncCopy,
mustGuardEpilogue);
(void)sp.pipelineLoop();
}
}
};
} // namespace

std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass(
int numStages, int globalPrefetch, int localPrefetch, bool useAsyncCopy) {
std::unique_ptr<Pass>
mlir::createTritonAMDGPUStreamPipelinePass(int numStages, int globalPrefetch,
int localPrefetch, bool useAsyncCopy,
bool mustGuardEpilogue) {
return std::make_unique<PipelinePass>(numStages, globalPrefetch,
localPrefetch, useAsyncCopy);
localPrefetch, useAsyncCopy,
mustGuardEpilogue);
}
4 changes: 2 additions & 2 deletions third_party/amd/python/triton_amd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
mlir::createTritonAMDGPUFoldTrueCmpIPass);
ADD_PASS_WRAPPER_1("add_block_pingpong",
mlir::createTritonAMDGPUBlockPingpongPass, int32_t);
ADD_PASS_WRAPPER_4("add_stream_pipeline",
ADD_PASS_WRAPPER_5("add_stream_pipeline",
mlir::createTritonAMDGPUStreamPipelinePass, int, int, int,
bool);
bool, bool);
ADD_PASS_WRAPPER_1("add_coalesce_async_copy",
mlir::createTritonAMDGPUCoalesceAsyncCopyPass,
std::string);
Expand Down
Loading