Skip to content

Commit fe8aacb

Browse files
committed
Cleanup StreamPipeliner changes
1 parent ee16b79 commit fe8aacb

File tree

1 file changed

+21
-31
lines changed

1 file changed

+21
-31
lines changed

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -261,34 +261,23 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
261261

262262
auto sharedEncodingAttr =
263263
cast<ttg::SharedEncodingAttr>(allocTy.getEncoding());
264-
llvm::outs() << "Shared alloc: \n";
265-
alloc.print(llvm::outs());
266-
llvm::outs() << "\n";
264+
auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
267265

268-
bool emitAsyncCopy = false;
266+
bool useAsyncCopy = false;
269267

270-
auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
271-
// We can use AsyncCopy if we do not swizzle into smem
272-
// TODO (alex) ensure it's 2D
268+
// Note that we can only use AsyncCopy when we have coalesced LDS writes (i.e. no
269+
// swizzling).
273270
if (triton::tools::getBoolEnv("AMDGCN_USE_DIRECT_TO_LDS") &&
274271
sharedEncodingAttr.getPerPhase() == 1 &&
275272
sharedEncodingAttr.getMaxPhase() == 1 &&
273+
sharedEncodingAttr.getOrder().size() == 2 &&
276274
llvm::equal(sharedEncodingAttr.getOrder(),
277275
ttg::getOrder(srcTy.getEncoding()))) {
278-
emitAsyncCopy = true;
276+
useAsyncCopy = true;
279277
}
280-
llvm::outs() << "Emit async: " << emitAsyncCopy << "\n";
281278

282279
SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
283280

284-
Operation *newLoadOp{};
285-
if (!emitAsyncCopy) {
286-
newLoadOp = builder.clone(*loadOp);
287-
auto [stage, cluster] = schedule[loadOp];
288-
schedule.erase(loadOp);
289-
schedule.insert(newLoadOp, stage, cluster);
290-
}
291-
292281
// Extract part.
293282
SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
294283
loadOffsets[0] = extractIdx;
@@ -300,14 +289,20 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
300289
auto viewLoad =
301290
builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
302291

292+
Operation *newLoadOp{};
303293
Operation *wait{};
304-
if (emitAsyncCopy) {
294+
295+
if (!useAsyncCopy) {
296+
newLoadOp = builder.clone(*loadOp);
297+
auto [stage, cluster] = schedule[loadOp];
298+
schedule.erase(loadOp);
299+
schedule.insert(newLoadOp, stage, cluster);
300+
} else {
305301
auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
306-
if (!srcTy) {
307-
llvm::outs() << "INVALID SRC!\n";
308-
}
302+
assert(srcTy);
303+
309304
// We need to ensure we read coalesced into LDS so we adjust the blocked layout to
310-
// read coalesced for now
305+
// read coalesced
311306

312307
auto shape = subviewTy.getShape();
313308
auto order = sharedEncodingAttr.getOrder();
@@ -325,19 +320,14 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
325320

326321
auto srcEncoding = srcTy.getEncoding();
327322
auto newLayout = ttg::BlockedEncodingAttr::get(
328-
loadOp->getContext(),
329-
sizePerThread, //{1, 1}, // triton::gpu::getSizePerThread(srcEncoding),
330-
threadsPerWarp, //{2, 32}, //
331-
// triton::gpu::getThreadsPerWarp(srcEncoding),
323+
loadOp->getContext(), sizePerThread, threadsPerWarp,
332324
triton::gpu::getWarpsPerCTA(srcEncoding),
333325
triton::gpu::getOrder(srcEncoding),
334326
triton::gpu::getCTALayout(srcEncoding));
335-
llvm::outs() << "New src encoding: ";
336327
newLayout.printStripped(llvm::outs());
337328
llvm::outs() << "\n";
338329
RankedTensorType newArgType = RankedTensorType::get(
339330
srcTy.getShape(), srcTy.getElementType(), newLayout);
340-
llvm::outs() << "Source encoding: ";
341331
srcTy.getEncoding().print(llvm::outs());
342332
llvm::outs() << "\n";
343333
auto cvtSrc =
@@ -377,7 +367,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
377367

378368
// Prefetch load ahead of the dot stage if is used by the dot.
379369
Operation *storeOp{};
380-
if (emitAsyncCopy) {
370+
if (useAsyncCopy) {
381371
// FIXME: it should be scheduled as a local_load to hide latency but that
382372
// currently breaks the scheduling as we require one more lds buffer to make
383373
// that work
@@ -391,7 +381,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
391381

392382
// Create local load
393383
Operation *sharedLoad{};
394-
if (emitAsyncCopy) {
384+
if (useAsyncCopy) {
395385
// scheduleOp(wait, SCHED_LOCAL_LOAD);
396386
sharedLoad = builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(),
397387
viewLoad, wait->getResult(0));
@@ -409,7 +399,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
409399
// instruction scheduling hints to correctly count the emitted `ds_write`
410400
// instructions for each GEMM tile.
411401
if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) {
412-
if (emitAsyncCopy) {
402+
if (useAsyncCopy) {
413403
newLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
414404
} else {
415405
storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);

0 commit comments

Comments
 (0)