Skip to content

Commit 9d346d0

Browse files
committed
Cleanup
1 parent be84d1e commit 9d346d0

File tree

1 file changed

+27
-31
lines changed

1 file changed

+27
-31
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 27 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -394,16 +394,16 @@ struct BufferLoadOpConversion
394394
}
395395
};
396396

397-
struct AsyncCopyToGlobalOpConversion
397+
struct AsyncCopyGlobalToLocalOpConversion
398398
: public ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>,
399399
public LoadStoreConversionBase {
400400
using ConvertOpToLLVMPattern<
401401
triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern;
402402

403-
AsyncCopyToGlobalOpConversion(LLVMTypeConverter &converter,
404-
const AMD::TargetInfo &targetInfo,
405-
ModuleAxisInfoAnalysis &axisAnalysisPass,
406-
PatternBenefit benefit)
403+
AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter,
404+
const AMD::TargetInfo &targetInfo,
405+
ModuleAxisInfoAnalysis &axisAnalysisPass,
406+
PatternBenefit benefit)
407407
: ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>(converter,
408408
benefit),
409409
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -437,21 +437,17 @@ struct AsyncCopyToGlobalOpConversion
437437
MLIRContext *ctx = rewriter.getContext();
438438
auto loc = op.getLoc();
439439

440-
Value mask = op.getMask();
441-
Value other = op.getOther();
442-
443440
auto srcTy = op.getSrc().getType();
444441
auto srcEncoding = srcTy.getEncoding();
445442
assert((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding) &&
446443
"Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion"));
447-
448-
auto dstTy = op.getResult().getType();
449-
auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
450-
451444
auto srcShape = srcTy.getShape();
452445
assert(srcShape.size() <= 2 && "Async copy only supports 1d and 2d "
453446
"tensors: Unexpected rank of %src");
454447

448+
auto dstTy = op.getResult().getType();
449+
auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
450+
455451
Value llSrc = adaptor.getSrc();
456452

457453
auto srcElems = unpackLLElements(loc, llSrc, rewriter);
@@ -460,21 +456,9 @@ struct AsyncCopyToGlobalOpConversion
460456
auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct(
461457
loc, llDst, resElemTy, rewriter);
462458

463-
Value llMask = adaptor.getMask();
464-
SmallVector<Value> maskElems;
465-
if (llMask) {
466-
maskElems = unpackLLElements(loc, llMask, rewriter);
467-
assert(srcElems.size() == maskElems.size());
468-
}
469-
470-
Value llOther = adaptor.getOther();
471-
SmallVector<Value> otherElems;
472-
if (llOther) {
473-
otherElems = unpackLLElements(loc, llOther, rewriter);
474-
assert(srcElems.size() == otherElems.size());
475-
}
476-
477459
unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass);
460+
461+
Value mask = op.getMask();
478462
if (mask) {
479463
maxVec = std::min(maxVec, getMaskAlignment(mask));
480464
}
@@ -521,13 +505,25 @@ struct AsyncCopyToGlobalOpConversion
521505

522506
int vecBytes = vecBits / 8;
523507
assert(llvm::isPowerOf2_32(vecBytes));
524-
525-
std::string intrinsic = "llvm.amdgcn.global.load.lds";
526508
Value vecBytesVal = i32_val(vecBytes);
527509

528510
Value cacheModifiers = i32_val(
529511
getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo));
530512

513+
Value llMask = adaptor.getMask();
514+
SmallVector<Value> maskElems;
515+
if (llMask) {
516+
maskElems = unpackLLElements(loc, llMask, rewriter);
517+
assert(srcElems.size() == maskElems.size());
518+
}
519+
520+
Value other = op.getOther();
521+
SmallVector<Value> otherElems;
522+
if (other) {
523+
otherElems = unpackLLElements(loc, adaptor.getOther(), rewriter);
524+
assert(srcElems.size() == otherElems.size());
525+
}
526+
531527
for (int i = 0; i < shmemAddrs.size(); i++) {
532528
auto srcIdx = i * maxVec;
533529
auto srcPtr = srcElems[srcIdx];
@@ -1632,7 +1628,7 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
16321628
ConversionPatternRewriter &rewriter) const override {
16331629

16341630
auto loc = op->getLoc();
1635-
rewriter.create<ROCDL::WaitcntOp>(loc, 0);
1631+
rewriter.create<ROCDL::WaitcntOp>(loc, op.getNum());
16361632
rewriter.replaceOp(op, i32_val(0));
16371633
return success();
16381634
}
@@ -1651,7 +1647,7 @@ struct AsyncCommitGroupConversion
16511647
LogicalResult
16521648
matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor,
16531649
ConversionPatternRewriter &rewriter) const override {
1654-
// We do not have that concept so simply drop it
1650+
// Drop the result token
16551651
auto loc = op->getLoc();
16561652
rewriter.replaceOp(op, i32_val(0));
16571653
return success();
@@ -1670,7 +1666,7 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
16701666
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
16711667
StoreOpConversion, BufferLoadOpConversion,
16721668
BufferStoreOpConversion, BufferAtomicRMWOpConversion,
1673-
AsyncCopyToGlobalOpConversion, AsyncCommitGroupConversion,
1669+
AsyncCopyGlobalToLocalOpConversion, AsyncCommitGroupConversion,
16741670
AsyncWaitConversion, AsyncCommitGroupConversion>(
16751671
typeConverter, targetInfo, axisInfoAnalysis, benefit);
16761672
}

0 commit comments

Comments
 (0)