 #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Schedule.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/Support/Debug.h"
 
 //===----------------------------------------------------------------------===//
@@ -258,25 +257,12 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   Value other = loadOp.getOther();
 
   ttg::MemDescType allocTy = cast<ttg::MemDescType>(alloc.getType());
-
-  auto sharedEncodingAttr =
-      cast<ttg::SharedEncodingAttr>(allocTy.getEncoding());
-  auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
-
-  bool useAsyncCopy = false;
-
-  // Note that we can only use AsyncCopy when we have coalesced LDS writes
-  // (e.g. no swizzling).
-  if (triton::tools::getBoolEnv("AMDGCN_USE_DIRECT_TO_LDS") &&
-      sharedEncodingAttr.getPerPhase() == 1 &&
-      sharedEncodingAttr.getMaxPhase() == 1 &&
-      sharedEncodingAttr.getOrder().size() == 2 &&
-      llvm::equal(sharedEncodingAttr.getOrder(),
-                  ttg::getOrder(srcTy.getEncoding()))) {
-    useAsyncCopy = true;
-  }
-
   SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
+  Operation *copy = builder.clone(*loadOp);
+
+  auto [stage, cluster] = schedule[loadOp];
+  schedule.erase(loadOp);
+  schedule.insert(copy, stage, cluster);
 
   // Extract part.
   SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
@@ -288,72 +274,6 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
       allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
   auto viewLoad =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
-
-  Operation *newLoadOp{};
-  Operation *wait{};
-
-  if (!useAsyncCopy) {
-    newLoadOp = builder.clone(*loadOp);
-    auto [stage, cluster] = schedule[loadOp];
-    schedule.erase(loadOp);
-    schedule.insert(newLoadOp, stage, cluster);
-  } else {
-    auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
-    assert(srcTy);
-
-    // We need to ensure we read coalesced into LDS, so we adjust the blocked
-    // layout to read coalesced.
-
-    auto shape = subviewTy.getShape();
-    auto order = sharedEncodingAttr.getOrder();
-    // Aim to use wider loads.
-    llvm::SmallVector<unsigned, 2> sizePerThread{1, 1};
-    sizePerThread[order[0]] =
-        32 / allocTy.getElementType().getIntOrFloatBitWidth();
-    llvm::SmallVector<unsigned, 2> threadsPerWarp{1, 1};
-    assert((shape[order[0]] % sizePerThread[0]) == 0);
-    unsigned warpSize = 64;
-    threadsPerWarp[order[0]] =
-        std::min<unsigned>(warpSize, shape[order[0]] / sizePerThread[order[0]]);
-    threadsPerWarp[order[1]] =
-        std::max<unsigned>(1, warpSize / threadsPerWarp[order[0]]);
-
-    auto srcEncoding = srcTy.getEncoding();
-    auto newLayout = ttg::BlockedEncodingAttr::get(
-        loadOp->getContext(), sizePerThread, threadsPerWarp,
-        triton::gpu::getWarpsPerCTA(srcEncoding),
-        triton::gpu::getOrder(srcEncoding),
-        triton::gpu::getCTALayout(srcEncoding));
-    newLayout.printStripped(llvm::outs());
-    llvm::outs() << "\n";
-    RankedTensorType newArgType = RankedTensorType::get(
-        srcTy.getShape(), srcTy.getElementType(), newLayout);
-    srcTy.getEncoding().print(llvm::outs());
-    llvm::outs() << "\n";
-    auto cvtSrc =
-        builder.create<ttg::ConvertLayoutOp>(loadOp.getLoc(), newArgType, src);
-
-    auto mask = loadOp.getMask();
-    if (mask) {
-      auto maskTy = dyn_cast<triton::gpu::TensorOrMemDesc>(mask.getType());
-      RankedTensorType newMaskTy = RankedTensorType::get(
-          maskTy.getShape(), maskTy.getElementType(), newLayout);
-      auto cvtMask = builder.create<ttg::ConvertLayoutOp>(
-          loadOp->getLoc(), newMaskTy, loadOp.getMask());
-    }
-
-    newLoadOp = builder.create<ttg::AsyncCopyGlobalToLocalOp>(
-        loadOp.getLoc(), cvtSrc.getResult(), viewLoad, mask, other,
-        loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
-
-    wait = builder.create<ttg::AsyncWaitOp>(loc, newLoadOp->getResult(0), 0);
-
-    auto [stage, cluster] = schedule[loadOp];
-    schedule.erase(loadOp);
-    schedule.insert(cvtSrc, stage, cluster);
-    schedule.insert(newLoadOp, stage, cluster);
-  }
-
   // Clean up old local caches.
   SmallVector<ttg::LocalAllocOp> allocsToErase;
   for (Operation *user : loadOp->getUsers()) {
@@ -366,30 +286,15 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   alloc.erase();
 
   // Prefetch the load ahead of the dot stage if it is used by the dot.
-  Operation *storeOp{};
-  if (useAsyncCopy) {
-    // FIXME: it should be scheduled as a local_load to hide latency, but that
-    // currently breaks the scheduling as we require one more LDS buffer to
-    // make that work.
-    scheduleOp(newLoadOp, SCHED_LOCAL_STORE);
-  } else {
-    storeOp = builder.create<ttg::LocalStoreOp>(loc, newLoadOp->getResult(0),
-                                                viewLoad);
-    scheduleOp(viewLoad, SCHED_LOCAL_STORE);
-    scheduleOp(storeOp, SCHED_LOCAL_STORE);
-  }
+  auto storeOp =
+      builder.create<ttg::LocalStoreOp>(loc, copy->getResult(0), viewLoad);
+  scheduleOp(viewLoad, SCHED_LOCAL_STORE);
+  scheduleOp(storeOp, SCHED_LOCAL_STORE);
 
   // Create local load.
-  Operation *sharedLoad{};
-  if (useAsyncCopy) {
-    // scheduleOp(wait, SCHED_LOCAL_LOAD);
-    sharedLoad = builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(),
-                                                  viewLoad, wait->getResult(0));
-  } else {
-    sharedLoad =
-        builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
-  }
-  Value result = sharedLoad->getResult(0);
+  auto sharedLoad =
+      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
+  Value result = sharedLoad.getResult();
   if (prefetch)
     scheduleOp(sharedLoad, SCHED_LOCAL_LOAD);
 
@@ -399,11 +304,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   // instruction scheduling hints to correctly count the emitted `ds_write`
   // instructions for each GEMM tile.
   if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) {
-    if (useAsyncCopy) {
-      newLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
-    } else {
-      storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
-    }
+    storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
   }
 
   loadOp->replaceAllUsesWith(ValueRange{result});
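
For reference, after this change the core of createStreamCopy reduces to the straight-line sequence sketched below. This is a condensed view assembled from the context and added lines of the diff above, not a verbatim copy of the file; names such as builder, loc, zero, subviewTy, schedule, scheduleOp, and prefetch come from the surrounding function and pass, and the local-alloc cleanup and OpIdx attribute propagation are omitted.

  // Clone the load and hand its pipeline stage/cluster over to the clone.
  ttg::MemDescType allocTy = cast<ttg::MemDescType>(alloc.getType());
  Operation *copy = builder.clone(*loadOp);
  auto [stage, cluster] = schedule[loadOp];
  schedule.erase(loadOp);
  schedule.insert(copy, stage, cluster);

  // Subview into the shared-memory buffer (offsets per the "Extract part"
  // logic above).
  SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
  auto viewLoad =
      builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);

  // Stage the cloned load's result through LDS.
  auto storeOp =
      builder.create<ttg::LocalStoreOp>(loc, copy->getResult(0), viewLoad);
  scheduleOp(viewLoad, SCHED_LOCAL_STORE);
  scheduleOp(storeOp, SCHED_LOCAL_STORE);

  // Read it back for the consumer and rewire the uses of the original load.
  auto sharedLoad =
      builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(), viewLoad);
  Value result = sharedLoad.getResult();
  if (prefetch)
    scheduleOp(sharedLoad, SCHED_LOCAL_LOAD);
  loadOp->replaceAllUsesWith(ValueRange{result});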