Skip to content

Commit c8d05e1

Browse files
committed
Cleanup
1 parent 435fb9e commit c8d05e1

File tree

1 file changed

+27
-31
lines changed

1 file changed

+27
-31
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 27 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -399,16 +399,16 @@ struct BufferLoadOpConversion
399399
}
400400
};
401401

402-
struct AsyncCopyToGlobalOpConversion
402+
struct AsyncCopyGlobalToLocalOpConversion
403403
: public ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>,
404404
public LoadStoreConversionBase {
405405
using ConvertOpToLLVMPattern<
406406
triton::gpu::AsyncCopyGlobalToLocalOp>::ConvertOpToLLVMPattern;
407407

408-
AsyncCopyToGlobalOpConversion(LLVMTypeConverter &converter,
409-
const AMD::TargetInfo &targetInfo,
410-
ModuleAxisInfoAnalysis &axisAnalysisPass,
411-
PatternBenefit benefit)
408+
AsyncCopyGlobalToLocalOpConversion(LLVMTypeConverter &converter,
409+
const AMD::TargetInfo &targetInfo,
410+
ModuleAxisInfoAnalysis &axisAnalysisPass,
411+
PatternBenefit benefit)
412412
: ConvertOpToLLVMPattern<triton::gpu::AsyncCopyGlobalToLocalOp>(converter,
413413
benefit),
414414
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -442,21 +442,17 @@ struct AsyncCopyToGlobalOpConversion
442442
MLIRContext *ctx = rewriter.getContext();
443443
auto loc = op.getLoc();
444444

445-
Value mask = op.getMask();
446-
Value other = op.getOther();
447-
448445
auto srcTy = op.getSrc().getType();
449446
auto srcEncoding = srcTy.getEncoding();
450447
assert((isa<BlockedEncodingAttr, SliceEncodingAttr>(srcEncoding) &&
451448
"Unexpected srcEncoding in AsyncCopyGlobalToLocalOpConversion"));
452-
453-
auto dstTy = op.getResult().getType();
454-
auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
455-
456449
auto srcShape = srcTy.getShape();
457450
assert(srcShape.size() <= 2 && "Async copy only supports 1d and 2d "
458451
"tensors: Unexpected rank of %src");
459452

453+
auto dstTy = op.getResult().getType();
454+
auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
455+
460456
Value llSrc = adaptor.getSrc();
461457

462458
auto srcElems = unpackLLElements(loc, llSrc, rewriter);
@@ -465,21 +461,9 @@ struct AsyncCopyToGlobalOpConversion
465461
auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct(
466462
loc, llDst, resElemTy, rewriter);
467463

468-
Value llMask = adaptor.getMask();
469-
SmallVector<Value> maskElems;
470-
if (llMask) {
471-
maskElems = unpackLLElements(loc, llMask, rewriter);
472-
assert(srcElems.size() == maskElems.size());
473-
}
474-
475-
Value llOther = adaptor.getOther();
476-
SmallVector<Value> otherElems;
477-
if (llOther) {
478-
otherElems = unpackLLElements(loc, llOther, rewriter);
479-
assert(srcElems.size() == otherElems.size());
480-
}
481-
482464
unsigned maxVec = getContiguity(op.getSrc(), axisAnalysisPass);
465+
466+
Value mask = op.getMask();
483467
if (mask) {
484468
maxVec = std::min(maxVec, getMaskAlignment(mask));
485469
}
@@ -526,13 +510,25 @@ struct AsyncCopyToGlobalOpConversion
526510

527511
int vecBytes = vecBits / 8;
528512
assert(llvm::isPowerOf2_32(vecBytes));
529-
530-
std::string intrinsic = "llvm.amdgcn.global.load.lds";
531513
Value vecBytesVal = i32_val(vecBytes);
532514

533515
Value cacheModifiers = i32_val(
534516
getCtrlBitsForCacheModifierOnTarget(op.getCache(), false, targetInfo));
535517

518+
Value llMask = adaptor.getMask();
519+
SmallVector<Value> maskElems;
520+
if (llMask) {
521+
maskElems = unpackLLElements(loc, llMask, rewriter);
522+
assert(srcElems.size() == maskElems.size());
523+
}
524+
525+
Value other = op.getOther();
526+
SmallVector<Value> otherElems;
527+
if (other) {
528+
otherElems = unpackLLElements(loc, adaptor.getOther(), rewriter);
529+
assert(srcElems.size() == otherElems.size());
530+
}
531+
536532
for (int i = 0; i < shmemAddrs.size(); i++) {
537533
auto srcIdx = i * maxVec;
538534
auto srcPtr = srcElems[srcIdx];
@@ -1652,7 +1648,7 @@ struct AsyncWaitConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
16521648
ConversionPatternRewriter &rewriter) const override {
16531649

16541650
auto loc = op->getLoc();
1655-
rewriter.create<ROCDL::WaitcntOp>(loc, 0);
1651+
rewriter.create<ROCDL::WaitcntOp>(loc, op.getNum());
16561652
rewriter.replaceOp(op, i32_val(0));
16571653
return success();
16581654
}
@@ -1671,7 +1667,7 @@ struct AsyncCommitGroupConversion
16711667
LogicalResult
16721668
matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor,
16731669
ConversionPatternRewriter &rewriter) const override {
1674-
// We do not have that concept so simply drop it
1670+
// Drop the result token
16751671
auto loc = op->getLoc();
16761672
rewriter.replaceOp(op, i32_val(0));
16771673
return success();
@@ -1690,7 +1686,7 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
16901686
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
16911687
StoreOpConversion, BufferLoadOpConversion,
16921688
BufferStoreOpConversion, BufferAtomicRMWOpConversion,
1693-
AsyncCopyToGlobalOpConversion, AsyncCommitGroupConversion,
1689+
AsyncCopyGlobalToLocalOpConversion, AsyncCommitGroupConversion,
16941690
AsyncWaitConversion, AsyncCommitGroupConversion>(
16951691
typeConverter, targetInfo, axisInfoAnalysis, benefit);
16961692
}

0 commit comments

Comments
 (0)