
Commit ca2b17a

Revert stream pipeline related changes
1 parent e0917de commit ca2b17a

2 files changed: +13 -113 lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 1 deletion

@@ -14,7 +14,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     // clang-format off
     "AMDGCN_ENABLE_DUMP",
     "AMDGCN_USE_BUFFER_OPS",
-    "AMDGCN_USE_DIRECT_TO_LDS",
     "DISABLE_FAST_REDUCTION",
     "DISABLE_LLVM_OPT",
     "DISABLE_MMA_V3",

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 13 additions & 112 deletions

@@ -10,7 +10,6 @@
 #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/Support/Debug.h"

 //===----------------------------------------------------------------------===//
@@ -258,25 +257,12 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   Value other = loadOp.getOther();

   ttg::MemDescType allocTy = cast<ttg::MemDescType>(alloc.getType());
-
-  auto sharedEncodingAttr =
-      cast<ttg::SharedEncodingAttr>(allocTy.getEncoding());
-  auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
-
-  bool useAsyncCopy = false;
-
-  // Note that we can only use AsyncCopy when have coalesced LDS writes (e.g. no
-  // swizzeling).
-  if (triton::tools::getBoolEnv("AMDGCN_USE_DIRECT_TO_LDS") &&
-      sharedEncodingAttr.getPerPhase() == 1 &&
-      sharedEncodingAttr.getMaxPhase() == 1 &&
-      sharedEncodingAttr.getOrder().size() == 2 &&
-      llvm::equal(sharedEncodingAttr.getOrder(),
-                  ttg::getOrder(srcTy.getEncoding()))) {
-    useAsyncCopy = true;
-  }
-
   SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
+  Operation *copy = builder.clone(*loadOp);
+
+  auto [stage, cluster] = schedule[loadOp];
+  schedule.erase(loadOp);
+  schedule.insert(copy, stage, cluster);

   // Extract part.
   SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
@@ -288,72 +274,6 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
       allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
   auto viewLoad =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
-
-  Operation *newLoadOp{};
-  Operation *wait{};
-
-  if (!useAsyncCopy) {
-    newLoadOp = builder.clone(*loadOp);
-    auto [stage, cluster] = schedule[loadOp];
-    schedule.erase(loadOp);
-    schedule.insert(newLoadOp, stage, cluster);
-  } else {
-    auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
-    assert(srcTy);
-
-    // We need to ensure we read coalesced into LDS so we adjust the blocked to
-    // read coalesced
-
-    auto shape = subviewTy.getShape();
-    auto order = sharedEncodingAttr.getOrder();
-    // Aim to use wider loads
-    llvm::SmallVector<unsigned, 2> sizePerThread{1, 1};
-    sizePerThread[order[0]] =
-        32 / allocTy.getElementType().getIntOrFloatBitWidth();
-    llvm::SmallVector<unsigned, 2> threadsPerWarp{1, 1};
-    assert((shape[order[0]] % sizePerThread[0]) == 0);
-    unsigned warpSize = 64;
-    threadsPerWarp[order[0]] =
-        std::min<unsigned>(warpSize, shape[order[0]] / sizePerThread[order[0]]);
-    threadsPerWarp[order[1]] =
-        std::max<unsigned>(1, warpSize / threadsPerWarp[order[0]]);
-
-    auto srcEncoding = srcTy.getEncoding();
-    auto newLayout = ttg::BlockedEncodingAttr::get(
-        loadOp->getContext(), sizePerThread, threadsPerWarp,
-        triton::gpu::getWarpsPerCTA(srcEncoding),
-        triton::gpu::getOrder(srcEncoding),
-        triton::gpu::getCTALayout(srcEncoding));
-    newLayout.printStripped(llvm::outs());
-    llvm::outs() << "\n";
-    RankedTensorType newArgType = RankedTensorType::get(
-        srcTy.getShape(), srcTy.getElementType(), newLayout);
-    srcTy.getEncoding().print(llvm::outs());
-    llvm::outs() << "\n";
-    auto cvtSrc =
-        builder.create<ttg::ConvertLayoutOp>(loadOp.getLoc(), newArgType, src);
-
-    auto mask = loadOp.getMask();
-    if (mask) {
-      auto maskTy = dyn_cast<triton::gpu::TensorOrMemDesc>(mask.getType());
-      RankedTensorType newMaskTy = RankedTensorType::get(
-          maskTy.getShape(), maskTy.getElementType(), newLayout);
-      auto cvtMask = builder.create<ttg::ConvertLayoutOp>(
-          loadOp->getLoc(), newMaskTy, loadOp.getMask());
-    }
-
-    newLoadOp = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
-        loadOp.getLoc(), cvtSrc.getResult(), viewLoad, mask, other,
-        loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
-
-    wait = builder.create<ttg::AsyncWaitOp>(loc, newLoadOp->getResult(0), 0);
-
-    auto [stage, cluster] = schedule[loadOp];
-    schedule.erase(loadOp);
-    schedule.insert(cvtSrc, stage, cluster);
-    schedule.insert(newLoadOp, stage, cluster);
-  }
-
   // Clean up old local caches.
   SmallVector<ttg::LocalAllocOp> allocsToErase;
   for (Operation *user : loadOp->getUsers()) {
@@ -366,30 +286,15 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   alloc.erase();

   // Prefetch load ahead of the dot stage if is used by the dot.
-  Operation *storeOp{};
-  if (useAsyncCopy) {
-    // FIXME: it should be scheduled as a local_load to hide latency but that
-    // currently breaks the scheduling as we require one more lds buffer to make
-    // that work
-    scheduleOp(newLoadOp, SCHED_LOCAL_STORE);
-  } else {
-    storeOp = builder.create<ttg::LocalStoreOp>(loc, newLoadOp->getResult(0),
-                                                viewLoad);
-    scheduleOp(viewLoad, SCHED_LOCAL_STORE);
-    scheduleOp(storeOp, SCHED_LOCAL_STORE);
-  }
+  auto storeOp =
+      builder.create<ttg::LocalStoreOp>(loc, copy->getResult(0), viewLoad);
+  scheduleOp(viewLoad, SCHED_LOCAL_STORE);
+  scheduleOp(storeOp, SCHED_LOCAL_STORE);

   // Create local load
-  Operation *sharedLoad{};
-  if (useAsyncCopy) {
-    // scheduleOp(wait, SCHED_LOCAL_LOAD);
-    sharedLoad = builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(),
-                                                  viewLoad, wait->getResult(0));
-  } else {
-    sharedLoad =
-        builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
-  }
-  Value result = sharedLoad->getResult(0);
+  auto sharedLoad =
+      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
+  Value result = sharedLoad.getResult();
   if (prefetch)
     scheduleOp(sharedLoad, SCHED_LOCAL_LOAD);

@@ -399,11 +304,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   // instruction scheduling hints to correctly count the emitted `ds_write`
   // instructions for each GEMM tile.
   if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) {
-    if (useAsyncCopy) {
-      newLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
-    } else {
-      storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
-    }
+    storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
   }

   loadOp->replaceAllUsesWith(ValueRange{result});
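
Taken together, the additions above restore the plain synchronous copy path in createStreamCopy: clone the global load, keep its stage/cluster in the pipeline schedule, store the loaded tile into the shared-memory subview, and read it back with a local load. A consolidated sketch of that restored flow, assembled from the "+" lines (builder, schedule, scheduleOp, loc, subviewTy, alloc, loadOffsets, and the SCHED_* constants belong to the surrounding StreamPipeliner pass and are assumed here, not redefined):

    // Re-issue the global load and give the clone the original load's slot in
    // the pipeline schedule.
    Operation *copy = builder.clone(*loadOp);
    auto [stage, cluster] = schedule[loadOp];
    schedule.erase(loadOp);
    schedule.insert(copy, stage, cluster);

    // Write the loaded tile into the per-stage shared-memory subview.
    auto viewLoad =
        builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
    auto storeOp =
        builder.create<ttg::LocalStoreOp>(loc, copy->getResult(0), viewLoad);
    scheduleOp(viewLoad, SCHED_LOCAL_STORE);
    scheduleOp(storeOp, SCHED_LOCAL_STORE);

    // Read the tile back from shared memory and redirect all users of the
    // original load to that value.
    auto sharedLoad =
        builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
    loadOp->replaceAllUsesWith(ValueRange{sharedLoad.getResult()});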
