@@ -399,16 +399,16 @@ struct BufferLoadOpConversion
399399 }
400400};
401401
402- struct AsyncCopyToGlobalOpConversion
402+ struct AsyncCopyGlobalToLocalOpConversion
403403 : public ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>,
404404 public LoadStoreConversionBase {
405405 using ConvertOpToLLVMPattern<
406406 triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern;
407407
408- AsyncCopyToGlobalOpConversion (LLVMTypeConverter &converter,
409- const AMD::TargetInfo &targetInfo,
410- ModuleAxisInfoAnalysis &axisAnalysisPass,
411- PatternBenefit benefit)
408+ AsyncCopyGlobalToLocalOpConversion (LLVMTypeConverter &converter,
409+ const AMD::TargetInfo &targetInfo,
410+ ModuleAxisInfoAnalysis &axisAnalysisPass,
411+ PatternBenefit benefit)
412412 : ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>(converter,
413413 benefit),
414414 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
@@ -442,21 +442,17 @@ struct AsyncCopyToGlobalOpConversion
442442 MLIRContext *ctx = rewriter.getContext ();
443443 auto loc = op.getLoc ();
444444
445- Value mask = op.getMask ();
446- Value other = op.getOther ();
447-
448445 auto srcTy = op.getSrc ().getType ();
449446 auto srcEncoding = srcTy.getEncoding ();
450447 assert ((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding) &&
451448 " Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion" ));
452-
453- auto dstTy = op.getResult ().getType ();
454- auto resElemTy = getTypeConverter ()->convertType (dstTy.getElementType ());
455-
456449 auto srcShape = srcTy.getShape ();
457450 assert (srcShape.size () <= 2 && " Async copy only supports 1d and 2d "
458451 " tensors: Unexpected rank of %src" );
459452
453+ auto dstTy = op.getResult ().getType ();
454+ auto resElemTy = getTypeConverter ()->convertType (dstTy.getElementType ());
455+
460456 Value llSrc = adaptor.getSrc ();
461457
462458 auto srcElems = unpackLLElements (loc, llSrc, rewriter);
@@ -465,21 +461,9 @@ struct AsyncCopyToGlobalOpConversion
465461 auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct (
466462 loc, llDst, resElemTy, rewriter);
467463
468- Value llMask = adaptor.getMask ();
469- SmallVector<Value> maskElems;
470- if (llMask) {
471- maskElems = unpackLLElements (loc, llMask, rewriter);
472- assert (srcElems.size () == maskElems.size ());
473- }
474-
475- Value llOther = adaptor.getOther ();
476- SmallVector<Value> otherElems;
477- if (llOther) {
478- otherElems = unpackLLElements (loc, llOther, rewriter);
479- assert (srcElems.size () == otherElems.size ());
480- }
481-
482464 unsigned maxVec = getContiguity (op.getSrc (), axisAnalysisPass);
465+
466+ Value mask = op.getMask ();
483467 if (mask) {
484468 maxVec = std::min (maxVec, getMaskAlignment (mask));
485469 }
@@ -526,13 +510,25 @@ struct AsyncCopyToGlobalOpConversion
526510
527511 int vecBytes = vecBits / 8 ;
528512 assert (llvm::isPowerOf2_32 (vecBytes));
529-
530- std::string intrinsic = " llvm.amdgcn.global.load.lds" ;
531513 Value vecBytesVal = i32_val (vecBytes);
532514
533515 Value cacheModifiers = i32_val (
534516 getCtrlBitsForCacheModifierOnTarget (op.getCache (), false , targetInfo));
535517
518+ Value llMask = adaptor.getMask ();
519+ SmallVector<Value> maskElems;
520+ if (llMask) {
521+ maskElems = unpackLLElements (loc, llMask, rewriter);
522+ assert (srcElems.size () == maskElems.size ());
523+ }
524+
525+ Value other = op.getOther ();
526+ SmallVector<Value> otherElems;
527+ if (other) {
528+ otherElems = unpackLLElements (loc, adaptor.getOther (), rewriter);
529+ assert (srcElems.size () == otherElems.size ());
530+ }
531+
536532 for (int i = 0 ; i < shmemAddrs.size (); i++) {
537533 auto srcIdx = i * maxVec;
538534 auto srcPtr = srcElems[srcIdx];
@@ -1652,7 +1648,7 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
16521648 ConversionPatternRewriter &rewriter) const override {
16531649
16541650 auto loc = op->getLoc ();
1655- rewriter.create <ROCDL::WaitcntOp>(loc, 0 );
1651+ rewriter.create <ROCDL::WaitcntOp>(loc, op. getNum () );
16561652 rewriter.replaceOp (op, i32_val (0 ));
16571653 return success ();
16581654 }
@@ -1671,7 +1667,7 @@ struct AsyncCommitGroupConversion
16711667 LogicalResult
16721668 matchAndRewrite (AsyncCommitGroupOp op, OpAdaptor adaptor,
16731669 ConversionPatternRewriter &rewriter) const override {
1674- // We do not have that concept so simply drop it
1670+ // Drop the result token
16751671 auto loc = op->getLoc ();
16761672 rewriter.replaceOp (op, i32_val (0 ));
16771673 return success ();
@@ -1690,7 +1686,7 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
16901686 patterns.add <AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
16911687 StoreOpConversion, BufferLoadOpConversion,
16921688 BufferStoreOpConversion, BufferAtomicRMWOpConversion,
1693- AsyncCopyToGlobalOpConversion , AsyncCommitGroupConversion,
1689+ AsyncCopyGlobalToLocalOpConversion , AsyncCommitGroupConversion,
16941690 AsyncWaitConversion, AsyncCommitGroupConversion>(
16951691 typeConverter, targetInfo, axisInfoAnalysis, benefit);
16961692}
0 commit comments