@@ -394,16 +394,16 @@ struct BufferLoadOpConversion
394394 }
395395};
396396
397- struct AsyncCopyToGlobalOpConversion
397+ struct AsyncCopyGlobalToLocalOpConversion
398398 : public ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>,
399399 public LoadStoreConversionBase {
400400 using ConvertOpToLLVMPattern<
401401 triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern;
402402
403- AsyncCopyToGlobalOpConversion (LLVMTypeConverter &converter,
404- const AMD::TargetInfo &targetInfo,
405- ModuleAxisInfoAnalysis &axisAnalysisPass,
406- PatternBenefit benefit)
403+ AsyncCopyGlobalToLocalOpConversion (LLVMTypeConverter &converter,
404+ const AMD::TargetInfo &targetInfo,
405+ ModuleAxisInfoAnalysis &axisAnalysisPass,
406+ PatternBenefit benefit)
407407 : ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>(converter,
408408 benefit),
409409 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
@@ -437,21 +437,17 @@ struct AsyncCopyToGlobalOpConversion
437437 MLIRContext *ctx = rewriter.getContext ();
438438 auto loc = op.getLoc ();
439439
440- Value mask = op.getMask ();
441- Value other = op.getOther ();
442-
443440 auto srcTy = op.getSrc ().getType ();
444441 auto srcEncoding = srcTy.getEncoding ();
445442 assert ((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding) &&
446443 " Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion" ));
447-
448- auto dstTy = op.getResult ().getType ();
449- auto resElemTy = getTypeConverter ()->convertType (dstTy.getElementType ());
450-
451444 auto srcShape = srcTy.getShape ();
452445 assert (srcShape.size () <= 2 && " Async copy only supports 1d and 2d "
453446 " tensors: Unexpected rank of %src" );
454447
448+ auto dstTy = op.getResult ().getType ();
449+ auto resElemTy = getTypeConverter ()->convertType (dstTy.getElementType ());
450+
455451 Value llSrc = adaptor.getSrc ();
456452
457453 auto srcElems = unpackLLElements (loc, llSrc, rewriter);
@@ -460,21 +456,9 @@ struct AsyncCopyToGlobalOpConversion
460456 auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct (
461457 loc, llDst, resElemTy, rewriter);
462458
463- Value llMask = adaptor.getMask ();
464- SmallVector<Value> maskElems;
465- if (llMask) {
466- maskElems = unpackLLElements (loc, llMask, rewriter);
467- assert (srcElems.size () == maskElems.size ());
468- }
469-
470- Value llOther = adaptor.getOther ();
471- SmallVector<Value> otherElems;
472- if (llOther) {
473- otherElems = unpackLLElements (loc, llOther, rewriter);
474- assert (srcElems.size () == otherElems.size ());
475- }
476-
477459 unsigned maxVec = getContiguity (op.getSrc (), axisAnalysisPass);
460+
461+ Value mask = op.getMask ();
478462 if (mask) {
479463 maxVec = std::min (maxVec, getMaskAlignment (mask));
480464 }
@@ -521,13 +505,25 @@ struct AsyncCopyToGlobalOpConversion
521505
522506 int vecBytes = vecBits / 8 ;
523507 assert (llvm::isPowerOf2_32 (vecBytes));
524-
525- std::string intrinsic = " llvm.amdgcn.global.load.lds" ;
526508 Value vecBytesVal = i32_val (vecBytes);
527509
528510 Value cacheModifiers = i32_val (
529511 getCtrlBitsForCacheModifierOnTarget (op.getCache (), false , targetInfo));
530512
513+ Value llMask = adaptor.getMask ();
514+ SmallVector<Value> maskElems;
515+ if (llMask) {
516+ maskElems = unpackLLElements (loc, llMask, rewriter);
517+ assert (srcElems.size () == maskElems.size ());
518+ }
519+
520+ Value other = op.getOther ();
521+ SmallVector<Value> otherElems;
522+ if (other) {
523+ otherElems = unpackLLElements (loc, adaptor.getOther (), rewriter);
524+ assert (srcElems.size () == otherElems.size ());
525+ }
526+
531527 for (int i = 0 ; i < shmemAddrs.size (); i++) {
532528 auto srcIdx = i * maxVec;
533529 auto srcPtr = srcElems[srcIdx];
@@ -1632,7 +1628,7 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
16321628 ConversionPatternRewriter &rewriter) const override {
16331629
16341630 auto loc = op->getLoc ();
1635- rewriter.create <ROCDL::WaitcntOp>(loc, 0 );
1631+ rewriter.create <ROCDL::WaitcntOp>(loc, op. getNum () );
16361632 rewriter.replaceOp (op, i32_val (0 ));
16371633 return success ();
16381634 }
@@ -1651,7 +1647,7 @@ struct AsyncCommitGroupConversion
16511647 LogicalResult
16521648 matchAndRewrite (AsyncCommitGroupOp op, OpAdaptor adaptor,
16531649 ConversionPatternRewriter &rewriter) const override {
1654- // We do not have that concept so simply drop it
1650+ // Drop the result token
16551651 auto loc = op->getLoc ();
16561652 rewriter.replaceOp (op, i32_val (0 ));
16571653 return success ();
@@ -1670,7 +1666,7 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
16701666 patterns.add <AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
16711667 StoreOpConversion, BufferLoadOpConversion,
16721668 BufferStoreOpConversion, BufferAtomicRMWOpConversion,
1673- AsyncCopyToGlobalOpConversion , AsyncCommitGroupConversion,
1669+ AsyncCopyGlobalToLocalOpConversion , AsyncCommitGroupConversion,
16741670 AsyncWaitConversion, AsyncCommitGroupConversion>(
16751671 typeConverter, targetInfo, axisInfoAnalysis, benefit);
16761672}
0 commit comments