
Commit cea35da

Authored by Ognjen Plavsic (plognjen, oplavsic)
[AMD] Add basics to allow bypass LDS for dot RHS (#5350)
The AMDBypassLDSForDotOperandPass implements a strategy to bypass the Local Data Share (LDS) for one of the operands of an MFMA dot operation. Under certain conditions, the dot layout of an operand allows loading it directly from HBM into VGPRs in the MFMA dot layout, without losing vectorization of the global loads and without increasing their number due to data shared between threads.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
Co-authored-by: Ognjen Plavsic <[email protected]>
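The pass is toggled by a new environment variable (see the GetEnv.hpp hunk below). As a rough illustration of how a compilation pipeline might gate it, here is a minimal sketch assuming triton::tools::getBoolEnv from GetEnv.hpp; the createTritonAMDGPUBypassLDSForDotOperand() factory name is an assumption, since this diff only shows the registration hook:

#include "mlir/Pass/PassManager.h"
#include "triton/Tools/Sys/GetEnv.hpp"

// Hypothetical sketch: schedule the bypass-LDS pass only when the user
// opts in via TRITON_HIP_BYPASS_LDS_FOR_DOT (added below to the set of
// cache-invalidating environment variables).
static void buildAMDOptPipeline(mlir::OpPassManager &pm) {
  if (mlir::triton::tools::getBoolEnv("TRITON_HIP_BYPASS_LDS_FOR_DOT"))
    // Assumed factory name; not shown in this commit.
    pm.addPass(mlir::createTritonAMDGPUBypassLDSForDotOperand());
}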
1 parent 734d9f2 commit cea35da

File tree

14 files changed: +473, -57 lines

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   // TritonAMDGPUTransforms passes
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
+  mlir::registerTritonAMDGPUBypassLDSForDotOperand();
   mlir::registerTritonAMDGPUReorderInstructions();
   mlir::registerTritonAMDGPUBlockPingpong();
   mlir::registerTritonAMDGPUStreamPipeline();

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -205,6 +205,8 @@ enum class MMALoadType {
 };
 MMALoadType getMMALoadType(Operation *loadOp);
 
+// Convert \param op operands and results to layout \param encoding.
+void convertOpEncoding(Attribute encoding, Operation *op);
 } // namespace mlir
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
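The new convertOpEncoding utility (hoisted out of CoalescePass; see Coalesce.cpp and Utility.cpp below) rewrites an op in place: tensor operands are wrapped in convert_layout ops targeting the new encoding, the op is re-created with re-typed results, and the results are converted back to their original layouts before the old op is erased. A minimal caller sketch mirroring the updated CoalescePass loop; layoutMap is an illustrative name for the op-to-encoding map the real pass computes from axis-info analysis:

#include "llvm/ADT/MapVector.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

// Sketch: apply a precomputed target encoding to each memory op.
static void applyLayouts(
    llvm::MapVector<mlir::Operation *, mlir::Attribute> &layoutMap) {
  for (auto &kv : layoutMap) {
    // Rewrites kv.first in place: operands are converted to kv.second,
    // results are converted back to their original layouts, and the
    // original op is erased.
    mlir::convertOpEncoding(kv.second, kv.first);
  }
}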

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_STREAM_PREFETCH",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
+    "TRITON_HIP_BYPASS_LDS_FOR_DOT",
     "TRITON_LLVM_DEBUG_ONLY",
     "TRITON_ENABLE_ASAN",
     "TRITON_OVERRIDE_ARCH",

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 1 addition & 50 deletions
@@ -104,55 +104,6 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
         threadsPerWarp, CTALayout);
   }
 
-  static Type getNewType(Type type, Attribute encoding) {
-    RankedTensorType tensorType = cast<RankedTensorType>(type);
-    return RankedTensorType::get(tensorType.getShape(),
-                                 tensorType.getElementType(), encoding);
-  }
-
-  void coalesceOp(Attribute encoding, Operation *op) {
-    OpBuilder builder(op);
-    // Convert operands
-    // For load/store with tensor pointers, we don't have to change the
-    // operands' type, we do this by changing the outputs' type of
-    // `make_tensor_ptr`
-    SmallVector<Value, 4> newArgs;
-    for (auto operand : op->getOperands()) {
-      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
-      if (tensorType &&
-          !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
-        Type newType = getNewType(tensorType, encoding);
-        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
-            op->getLoc(), newType, operand));
-      } else {
-        newArgs.push_back(operand);
-      }
-    }
-
-    // Convert output types
-    SmallVector<Type, 4> newTypes;
-    for (auto t : op->getResultTypes()) {
-      bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
-      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
-    }
-
-    // Construct new op with the new encoding
-    Operation *newOp =
-        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
-                       newTypes, op->getAttrs());
-
-    // Cast the results back to the original layout
-    for (size_t i = 0; i < op->getNumResults(); i++) {
-      Value newResult = newOp->getResult(i);
-      if (newTypes[i] != op->getResultTypes()[i]) {
-        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
-            op->getLoc(), op->getResult(i).getType(), newResult);
-      }
-      op->getResult(i).replaceAllUsesWith(newResult);
-    }
-    op->erase();
-  }
-
   void runOnOperation() override {
     // Run axis info analysis
     ModuleOp moduleOp = getOperation();
@@ -187,7 +138,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
     // 4. Convert the output of this new memory op back to L1
     // 5. Replace all the uses of the original memory op by the new one
     for (auto &kv : layoutMap) {
-      coalesceOp(kv.second, kv.first);
+      convertOpEncoding(kv.second, kv.first);
     }
   }
 };

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 42 additions & 6 deletions
@@ -1022,6 +1022,43 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   }
 }
 
+bool shouldPropagateConversion(ConvertLayoutOp convertOp) {
+  RankedTensorType targetType = convertOp.getType();
+  auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
+  // If the target encoding is not DotOperandEncodingAttr, allow propagation.
+  if (!dotEnc) {
+    return true;
+  }
+  // Skip conversions to DotOperandEncodingAttr when the operand index is 0.
+  // This heuristic is applied to prevent moving the blocked->dot conversion of
+  // the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
+  // so can increase register pressure and cause spilling in some cases.
+  if (dotEnc.getOpIdx() == 0) {
+    return false;
+  }
+  // Skip conversions to DotOperandEncodingAttr when the operand index is 1 if
+  // it's not intentionally placed above a load, as we have to be a bit more
+  // careful with the heuristics for both correctness and performance.
+  // TODO: Fix this logic to avoid propagating conversions backward unless
+  // it reduces the total number of conversions.
+  assert(dotEnc.getOpIdx() == 1);
+  SetVector<Operation *> slice;
+  BackwardSliceOptions opt;
+  opt.omitBlockArguments = true;
+  opt.filter = [&](Operation *op) {
+    return op->getParentRegion() == convertOp->getParentRegion();
+  };
+  getBackwardSlice(convertOp.getOperation(), &slice, opt);
+
+  for (Operation *currOp : slice) {
+    if (isa<LoadOp>(currOp)) {
+      return false;
+    }
+  }
+  // Allow propagation if no LoadOp is found.
+  return true;
+}
+
 void LayoutRematerialization::hoistConvertIntoConditionals() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
@@ -1040,11 +1077,11 @@
 
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
-  // we don't handle conversions to DotOperandEncodingAttr
-  // this is a heuristic to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
+  if (!shouldPropagateConversion(convertOp)) {
     return;
+  }
+
   Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
@@ -1083,11 +1120,10 @@
 // of the convert.
 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     ConvertLayoutOp convertOp) {
-  // we don't handle conversions to DotOperandEncodingAttr
-  // this is a heuristic to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
+  if (!shouldPropagateConversion(convertOp)) {
     return;
+  }
 
   auto isExtOrBroadcastOp = [](Operation *op) {
     if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp, BroadcastOp,

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 48 additions & 0 deletions
@@ -1057,6 +1057,54 @@ MMALoadType getMMALoadType(Operation *loadOp) {
   }
 }
 
+static Type getNewType(Type type, Attribute encoding) {
+  RankedTensorType tensorType = cast<RankedTensorType>(type);
+  return RankedTensorType::get(tensorType.getShape(),
+                               tensorType.getElementType(), encoding);
+}
+
+void convertOpEncoding(Attribute encoding, Operation *op) {
+  OpBuilder builder(op);
+  // Convert operands
+  // For load/store with tensor pointers, we don't have to change the
+  // operands' type, we do this by changing the outputs' type of
+  // `make_tensor_ptr`
+  SmallVector<Value, 4> newArgs;
+  for (auto operand : op->getOperands()) {
+    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+    if (tensorType &&
+        !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
+      Type newType = getNewType(tensorType, encoding);
+      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
+          op->getLoc(), newType, operand));
+    } else {
+      newArgs.push_back(operand);
+    }
+  }
+
+  // Convert output types
+  SmallVector<Type, 4> newTypes;
+  for (auto t : op->getResultTypes()) {
+    bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
+    newTypes.push_back(isAsync ? t : getNewType(t, encoding));
+  }
+
+  // Construct new op with the new encoding
+  Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(),
+                                    newArgs, newTypes, op->getAttrs());
+
+  // Cast the results back to the original layout
+  for (size_t i = 0; i < op->getNumResults(); i++) {
+    Value newResult = newOp->getResult(i);
+    if (newTypes[i] != op->getResultTypes()[i]) {
+      newResult = builder.create<triton::gpu::ConvertLayoutOp>(
+          op->getLoc(), op->getResult(i).getType(), newResult);
+    }
+    op->getResult(i).replaceAllUsesWith(newResult);
+  }
+  op->erase();
+}
+
 namespace {
 
 /// Detect dead arguments in scf.for op by assuming all the values are dead and
