Skip to content

Commit a6e7434

Browse files
authored
[SWP] Dedup the code that checks if LoadOp can be converted to cpasync (#8529)
During SWP, we check twice whether a given `LoadOp` should be lowered to `AsyncCopyGlobalToLocalOp` - first in `AssignLatencies`, and then again in `LowerLoops`. The two checks duplicate non-trivial conditions such as `copyVecBytes >= 4` and `op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32`. I moved the `isPipeliningBeneficial` function from `AssignLatencies` into the pipelining utilities so that it can also be used by `LowerLoops`. This will also be used by WS to determine whether a `LoadOp` should be lowered to cp.async and assigned to the load partition.
1 parent d703656 commit a6e7434

File tree

4 files changed

+56
-78
lines changed

4 files changed

+56
-78
lines changed

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ getLastUseOfPipelinedOp(ArrayRef<Operation *> ops, scf::ForOp forOp,
184184

185185
// Clean up attributes passing over schedules across stages in pipelining
186186
void removePipeliningAttributes(ModuleOp moduleOp);
187+
188+
// For LoadOp, DescriptorLoad, and DescriptorGather ops, determine if
189+
// they should be pipelined.
190+
bool isPipeliningBeneficial(Operation *op,
191+
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
192+
bool filterSmall = true);
193+
187194
} // namespace triton
188195
} // namespace mlir
189196

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -88,64 +88,6 @@ class AssignLoadLatencies {
8888
scf::ForOp forOp;
8989
int numStages;
9090
DenseMap<Operation *, int> &opLatency;
91-
92-
public:
93-
// Returns false when the load's users carry mutually incompatible dot-operand
// encodings, i.e. no single shared encoding can serve all of them.
static bool canHaveSharedEncoding(tt::LoadOp op) {
  bool hasIncompatibleUsers = false;
  // We only need the incompatibility flag; the returned encoding is ignored.
  getSharedEncIfAllUsersAreDotEnc(op.getResult(), hasIncompatibleUsers);
  return !hasIncompatibleUsers;
}
99-
100-
static bool
101-
isPipeliningBeneficial(Operation *op, Operation *finalUser,
102-
tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
103-
bool filterSmall) {
104-
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
105-
if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
106-
LDBG("Load " << *loadOp << " is too small for pipelining");
107-
return false;
108-
}
109-
}
110-
if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
111-
return true;
112-
if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
113-
LDBG("Load " << *op << " cannot have shared encoding");
114-
return false;
115-
}
116-
117-
ttg::SharedEncodingTrait localAllocEnc;
118-
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
119-
return isa<ttg::LocalAllocOp>(user);
120-
})) {
121-
for (auto user : op->getUsers()) {
122-
auto localAlloc = dyn_cast<ttg::LocalAllocOp>(user);
123-
if (!localAlloc)
124-
continue;
125-
auto enc = mlir::cast<ttg::SharedEncodingTrait>(
126-
localAlloc.getType().getEncoding());
127-
if (!localAllocEnc) {
128-
localAllocEnc = enc;
129-
}
130-
if (enc != localAllocEnc) {
131-
// If the load is used by a LocalAllocOp, all the users need to have
132-
// the same encoding.
133-
return false;
134-
}
135-
}
136-
}
137-
138-
if (localAllocEnc) {
139-
auto registerTy = cast<RankedTensorType>(op->getResultTypes()[0]);
140-
auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc);
141-
if (filterSmall && vecBytes < 4) {
142-
// At least 4 bytes need to be consecutive for cp.async
143-
return false;
144-
}
145-
}
146-
147-
return true;
148-
}
14991
};
15092

15193
class AssignMMALatencies {
@@ -280,8 +222,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
280222
if (!seen.insert(op).second || excluded.count(op))
281223
return;
282224
if (isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op)) {
283-
if (!AssignLoadLatencies::isPipeliningBeneficial(
284-
op, finalUser, axisInfoAnalysis, filterSmall))
225+
if (!isPipeliningBeneficial(op, axisInfoAnalysis, filterSmall))
285226
return;
286227
if (loadOpToIndLevel.count(op)) {
287228
int level = loadOpToIndLevel[op].first;

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -453,26 +453,17 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
453453
continue;
454454
}
455455
SharedEncodingTrait sharedEncoding;
456-
bool canUseAsyncCp = false;
457-
if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
458-
canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32;
459-
sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
460-
forOp.getContext(), 1, 1, 1, {0},
461-
ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
462-
if (canUseAsyncCp) {
456+
bool canUseAsyncCp =
457+
triton::isPipeliningBeneficial(&op, axisInfoAnalysis);
458+
if (canUseAsyncCp) {
459+
if (!isa<RankedTensorType>(op.getResultTypes()[0])) {
460+
sharedEncoding = ttg::SwizzledSharedEncodingAttr::get(
461+
forOp.getContext(), 1, 1, 1, {0},
462+
ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0}));
463463
scalarLoads.push_back(&op);
464+
} else {
465+
sharedEncoding = getSharedEncoding(&op);
464466
}
465-
} else {
466-
sharedEncoding = getSharedEncoding(&op);
467-
// Do not create async loads for small loads (cp.async requires at least
468-
// 4 bytes)
469-
canUseAsyncCp =
470-
isa<tt::LoadOp>(op) &&
471-
canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
472-
int copyVecBytes = getCopyVecBytes(
473-
cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);
474-
475-
canUseAsyncCp &= copyVecBytes >= 4;
476467
}
477468
if (canUseAsyncCp || isTMALoad(&op)) {
478469
if (loadRequiresAdditionalBuffer(&op)) {

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,10 @@ ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) {
603603
}
604604

605605
ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(Operation *op) {
606+
if (!isa<RankedTensorType>(op->getResultTypes()[0])) {
607+
return nullptr;
608+
}
609+
606610
// Try to use local alloc encoding if possible.
607611
ttg::SharedEncodingTrait localAllocEnc;
608612
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
@@ -933,3 +937,38 @@ void triton::removePipeliningAttributes(ModuleOp moduleOp) {
933937
op->removeAttr(mlir::triton::kScheduledMaxStageAttrName);
934938
});
935939
}
940+
941+
// A load can be given a shared encoding unless users with dot-operand
// encodings disagree with each other.
static bool canHaveSharedEncoding(tt::LoadOp op) {
  bool incompatibleUse = false;
  // Only the flag matters here; discard the computed encoding.
  getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatibleUse);
  return !incompatibleUse;
}
947+
948+
// Shared SWP/WS helper: decide whether a tt.load / descriptor load / descriptor
// gather should be pipelined (and, for tt.load, lowered to cp.async). With
// `filterSmall` set, loads that cannot sustain the 4-byte cp.async minimum are
// rejected.
bool triton::isPipeliningBeneficial(
    Operation *op, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis,
    bool filterSmall) {
  // A plain load must first qualify for conversion to an async copy.
  if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
    if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
      LDBG("Load " << *loadOp << " is too small for pipelining");
      return false;
    }
  }
  // Descriptor loads and gathers are always worth pipelining.
  if (isa<tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
    return true;
  // Remaining case is tt.load: its dot users must admit one shared encoding.
  if (!canHaveSharedEncoding(cast<tt::LoadOp>(op))) {
    LDBG("Load " << *op << " cannot have shared encoding");
    return false;
  }

  // If a shared encoding can be derived (null for non-tensor results), make
  // sure the copy is wide enough for cp.async.
  ttg::SharedEncodingTrait sharedEnc = getSharedEncoding(op);
  if (sharedEnc) {
    auto regTy = cast<RankedTensorType>(op->getResultTypes()[0]);
    auto vecBytes = mlir::triton::getCopyVecBytes(regTy, sharedEnc);
    // At least 4 bytes need to be consecutive for cp.async
    if (filterSmall && vecBytes < 4)
      return false;
  }

  return true;
}

0 commit comments

Comments
 (0)