@@ -261,34 +261,23 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
 
   auto sharedEncodingAttr =
       cast<ttg::SharedEncodingAttr>(allocTy.getEncoding());
-  llvm::outs() << "Shared alloc:\n";
-  alloc.print(llvm::outs());
-  llvm::outs() << "\n";
+  auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
 
-  bool emitAsyncCopy = false;
+  bool useAsyncCopy = false;
 
-  auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
-  // We can use AsyncCopy if we do not swizzle into smem
-  // TODO (alex) ensure it's 2D
+  // Note that we can only use AsyncCopy when we have coalesced LDS writes
+  // (i.e. no swizzling).
   if (triton::tools::getBoolEnv("AMDGCN_USE_DIRECT_TO_LDS") &&
       sharedEncodingAttr.getPerPhase() == 1 &&
       sharedEncodingAttr.getMaxPhase() == 1 &&
+      sharedEncodingAttr.getOrder().size() == 2 &&
       llvm::equal(sharedEncodingAttr.getOrder(),
                   ttg::getOrder(srcTy.getEncoding()))) {
-    emitAsyncCopy = true;
+    useAsyncCopy = true;
   }
-  llvm::outs() << "Emit async: " << emitAsyncCopy << "\n";
 
   SmallVector<Value> copyOffsets(allocTy.getRank(), zero);
 
-  Operation *newLoadOp{};
-  if (!emitAsyncCopy) {
-    newLoadOp = builder.clone(*loadOp);
-    auto [stage, cluster] = schedule[loadOp];
-    schedule.erase(loadOp);
-    schedule.insert(newLoadOp, stage, cluster);
-  }
-
   // Extract part.
   SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
   loadOffsets[0] = extractIdx;
@@ -300,14 +289,20 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   auto viewLoad =
       builder.create<ttg::MemDescSubviewOp>(loc, subviewTy, alloc, loadOffsets);
 
+  Operation *newLoadOp{};
   Operation *wait{};
-  if (emitAsyncCopy) {
+
+  if (!useAsyncCopy) {
+    newLoadOp = builder.clone(*loadOp);
+    auto [stage, cluster] = schedule[loadOp];
+    schedule.erase(loadOp);
+    schedule.insert(newLoadOp, stage, cluster);
+  } else {
     auto srcTy = dyn_cast<triton::gpu::TensorOrMemDesc>(src.getType());
-    if (!srcTy) {
-      llvm::outs() << "INVALID SRC!\n";
-    }
+    assert(srcTy);
+
     // We need to ensure we read coalesced into LDS so we adjust the blocked to
-    // read coalesced for now
+    // read coalesced
 
     auto shape = subviewTy.getShape();
     auto order = sharedEncodingAttr.getOrder();
@@ -325,19 +320,14 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
 
     auto srcEncoding = srcTy.getEncoding();
     auto newLayout = ttg::BlockedEncodingAttr::get(
-        loadOp->getContext(),
-        sizePerThread, // {1, 1}, // triton::gpu::getSizePerThread(srcEncoding),
-        threadsPerWarp, // {2, 32}, //
-        // triton::gpu::getThreadsPerWarp(srcEncoding),
+        loadOp->getContext(), sizePerThread, threadsPerWarp,
         triton::gpu::getWarpsPerCTA(srcEncoding),
         triton::gpu::getOrder(srcEncoding),
         triton::gpu::getCTALayout(srcEncoding));
-    llvm::outs() << "New src encoding: ";
     newLayout.printStripped(llvm::outs());
     llvm::outs() << "\n";
     RankedTensorType newArgType = RankedTensorType::get(
         srcTy.getShape(), srcTy.getElementType(), newLayout);
-    llvm::outs() << "Source encoding: ";
     srcTy.getEncoding().print(llvm::outs());
     llvm::outs() << "\n";
     auto cvtSrc =
@@ -377,7 +367,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
 
   // Prefetch load ahead of the dot stage if is used by the dot.
   Operation *storeOp{};
-  if (emitAsyncCopy) {
+  if (useAsyncCopy) {
     // FIXME: it should be scheduled as a local_load to hide latency but that
     // currently breaks the scheduling as we require one more lds buffer to make
     // that work
@@ -391,7 +381,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
 
   // Create local load
   Operation *sharedLoad{};
-  if (emitAsyncCopy) {
+  if (useAsyncCopy) {
     // scheduleOp(wait, SCHED_LOCAL_LOAD);
     sharedLoad = builder.create<ttg::LocalLoadOp>(loc, loadOp.getType(),
                                                   viewLoad, wait->getResult(0));
@@ -409,7 +399,7 @@ void StreamPipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc,
   // instruction scheduling hints to correctly count the emitted `ds_write`
   // instructions for each GEMM tile.
   if (auto attr = loadOp->getAttr(triton::amdgpu::OpIdxAttr::getMnemonic())) {
-    if (emitAsyncCopy) {
+    if (useAsyncCopy) {
       newLoadOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
     } else {
       storeOp->setAttr(triton::amdgpu::OpIdxAttr::getMnemonic(), attr);
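
Read together, the new gate reduces to a single predicate: direct-to-LDS AsyncCopy is taken only when the feature is opted in via AMDGCN_USE_DIRECT_TO_LDS and LDS writes stay coalesced. Below is a minimal standalone sketch of that condition; SharedEnc and canUseAsyncCopy are hypothetical stand-ins for the MLIR attribute types used in the pass, and the environment handling is simplified relative to triton::tools::getBoolEnv.

#include <cstdlib>
#include <vector>

// Hypothetical stand-in for the shared encoding attribute queried above.
struct SharedEnc {
  int perPhase;
  int maxPhase;
  std::vector<unsigned> order;
};

// Eligibility sketch: the shared layout must be unswizzled
// (perPhase == maxPhase == 1), two-dimensional, and ordered the same way as
// the source tensor, so each wave's global read lands contiguously in LDS.
bool canUseAsyncCopy(const SharedEnc &shared,
                     const std::vector<unsigned> &srcOrder) {
  const char *env = std::getenv("AMDGCN_USE_DIRECT_TO_LDS");
  if (!env || env[0] != '1') // simplified opt-in check
    return false;
  return shared.perPhase == 1 && shared.maxPhase == 1 &&
         shared.order.size() == 2 && shared.order == srcOrder;
}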
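The comments deleted from the BlockedEncodingAttr::get call ({1, 1} for sizePerThread, {2, 32} for threadsPerWarp) hint at how the adjusted blocked layout keeps reads coalesced. The actual derivation of those values is elided from this hunk, so the sketch below is only a guess at the intent, assuming a 64-lane wave and sizePerThread fixed at {1, 1}.

#include <algorithm>
#include <array>
#include <cstdint>

// Illustrative only: pick threadsPerWarp for a 2D blocked layout so that
// consecutive lanes walk the fastest-varying dimension, giving one contiguous
// global read per wave and hence coalesced LDS writes.
std::array<int64_t, 2> chooseThreadsPerWarp(std::array<int64_t, 2> shape,
                                            std::array<unsigned, 2> order,
                                            int64_t warpSize = 64) {
  std::array<int64_t, 2> threadsPerWarp = {1, 1};
  unsigned fast = order[0]; // contiguous dimension
  unsigned slow = order[1];
  // e.g. shape = {64, 32}, order = {1, 0}: 32 lanes along dim 1 and the
  // remaining 2 along dim 0, i.e. threadsPerWarp = {2, 32}.
  threadsPerWarp[fast] = std::min(shape[fast], warpSize);
  threadsPerWarp[slow] = warpSize / threadsPerWarp[fast];
  return threadsPerWarp;
}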