Skip to content

Commit ac79534

Browse files
authored
[Blackwell][Clean up] Remove use of SharedMemoryObject on TMEM (#5817)
Using `getSharedMemoryObjectFromStruct`, etc., on TMEM is very confusing for new readers, so I'm introducing simpler alternatives for TMEM. In practice, we only need the base address of TMEM. @ThomasRaoux --------- Co-authored-by: Masahiro Masuda <[email protected]>
1 parent 5656701 commit ac79534

File tree

3 files changed

+24
-48
lines changed

3 files changed

+24
-48
lines changed

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,17 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
4747
Type TritonGPUToLLVMTypeConverter::convertMemDescType(
4848
MemDescType type, const TargetInfoBase &targetInfo) {
4949
auto ctx = type.getContext();
50-
SmallVector<Type, 4> types;
5150
// base ptr
5251
auto ptrType =
5352
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
53+
54+
if (isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
55+
triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
56+
type.getEncoding())) {
57+
return ptrType;
58+
}
59+
60+
SmallVector<Type, 4> types;
5461
types.push_back(ptrType);
5562
auto rank = type.getRank();
5663
// offsets

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ using namespace mlir::triton::gpu;
1111
using namespace mlir::triton::NVIDIA;
1212

1313
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
14-
using ::mlir::triton::gpu::getShapePerCTA;
15-
using ::mlir::triton::gpu::getShapePerCTATile;
16-
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
1714
using ::mlir::triton::gpu::NVMMASharedEncodingAttr;
1815

1916
mlir::triton::NVIDIA::DotOpMmaV5TmemLoader::DotOpMmaV5TmemLoader(
@@ -360,11 +357,6 @@ void convertDot(const LLVMTypeConverter *typeConverter,
360357
loc, loadedB, typeConverter->convertType(bTensorTy.getElementType()),
361358
rewriter)
362359
.getBase();
363-
Value baseD =
364-
getSharedMemoryObjectFromStruct(
365-
loc, loadedD, typeConverter->convertType(dTensorTy.getElementType()),
366-
rewriter)
367-
.getBase();
368360

369361
SmallVector<int64_t> dstPerCTA = triton::gpu::getShapePerCTA(dTensorTy);
370362
unsigned int M = dstPerCTA[0];
@@ -404,7 +396,7 @@ void convertDot(const LLVMTypeConverter *typeConverter,
404396
{(unsigned)mmaSizeN, (unsigned)mmaSizeK},
405397
bTensorTy.getElementTypeBitWidth(), rewriter, loc);
406398
DotOpMmaV5TmemLoader dLoader = DotOpMmaV5TmemLoader(
407-
d, baseD, {(unsigned)mmaSizeM, (unsigned)mmaSizeN}, interleaved, false);
399+
d, loadedD, {(unsigned)mmaSizeM, (unsigned)mmaSizeN}, interleaved, false);
408400
for (int m = 0; m < numRepM; m++) {
409401
for (int n = 0; n < numRepN; n++) {
410402
Value useInitAcc = useDFlag;
@@ -505,18 +497,10 @@ struct TCGen5MMAScaledOpConversion
505497
loc, adaptor.getB(),
506498
typeConverter->convertType(bTensorTy.getElementType()), rewriter)
507499
.getBase();
508-
Value baseD =
509-
getSharedMemoryObjectFromStruct(
510-
loc, adaptor.getD(),
511-
typeConverter->convertType(dTensorTy.getElementType()), rewriter)
512-
.getBase();
500+
Value baseD = adaptor.getD();
513501
baseD = tb.ptrtoint(i32_ty, baseD);
514-
Value baseScaleA = getSharedMemoryObjectFromStruct(loc, adaptor.getAScale(),
515-
i8_ty, rewriter)
516-
.getBase();
517-
Value baseScaleB = getSharedMemoryObjectFromStruct(loc, adaptor.getBScale(),
518-
i8_ty, rewriter)
519-
.getBase();
502+
Value baseScaleA = adaptor.getAScale();
503+
Value baseScaleB = adaptor.getBScale();
520504
baseScaleA = tb.ptrtoint(i32_ty, baseScaleA);
521505
baseScaleB = tb.ptrtoint(i32_ty, baseScaleB);
522506

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ static void reorderScales(SmallVector<Value> &srcValues, int64_t k) {
222222

223223
static void lowerStoreToTensorMemory(Location loc, ModuleOp mod, Value src,
224224
Value dest, Value llSrc, Value pred,
225-
SharedMemoryObject smemObj,
225+
Value tmemBase,
226226
ConversionPatternRewriter &rewriter) {
227227
auto b = TritonLLVMOpBuilder(loc, rewriter);
228228
SmallVector<Value> srcValues = unpackLLElements(loc, llSrc, rewriter);
@@ -236,7 +236,7 @@ static void lowerStoreToTensorMemory(Location loc, ModuleOp mod, Value src,
236236
}
237237
int regIdx = 0;
238238
calculateAddressAndEmitTmemMessage(
239-
loc, mod, smemObj.getBase(), cast<RankedTensorType>(src.getType()),
239+
loc, mod, tmemBase, cast<RankedTensorType>(src.getType()),
240240
cast<MemDescType>(dest.getType()), rewriter,
241241
[&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
242242
int regsPerMessage, bool useStridedMessage) {
@@ -283,17 +283,13 @@ struct TensorMemoryAllocOpConversion
283283
std::iota(order.begin(), order.end(), 0);
284284
std::reverse(order.begin(), order.end());
285285
auto shape = op.getType().getShape();
286-
auto smemObj = SharedMemoryObject(ptr, op.getType().getElementType(),
287-
shape.size(), loc, rewriter);
288286

289287
if (op.getSrc()) {
290288
lowerStoreToTensorMemory(loc, mod, op.getSrc(), op.getResult(),
291-
adaptor.getSrc(), b.i1_val(true), smemObj,
292-
rewriter);
289+
adaptor.getSrc(), b.i1_val(true), ptr, rewriter);
293290
}
294291

295-
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
296-
rewriter.replaceOp(op, retVal);
292+
rewriter.replaceOp(op, ptr);
297293
return success();
298294
}
299295
};
@@ -381,13 +377,11 @@ struct TensorMemoryLoadOpConversion
381377
auto mod = op->getParentOfType<ModuleOp>();
382378
auto llvmElemTy =
383379
getTypeConverter()->convertType(op.getSrc().getType().getElementType());
384-
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
385-
op.getLoc(), adaptor.getSrc(), llvmElemTy, rewriter);
380+
auto tmemBase = adaptor.getSrc();
386381

387382
SmallVector<Value> resultVals;
388383
calculateAddressAndEmitTmemMessage(
389-
loc, mod, smemObj.getBase(), op.getType(), op.getSrc().getType(),
390-
rewriter,
384+
loc, mod, tmemBase, op.getType(), op.getSrc().getType(), rewriter,
391385
[&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
392386
int regsPerMessage, bool useStridedMessage) {
393387
Value packedValues = createTensorMemoryLoad(
@@ -420,11 +414,10 @@ struct TensorMemoryStoreOpConversion
420414
auto mod = op->getParentOfType<ModuleOp>();
421415
auto llvmElemTy =
422416
getTypeConverter()->convertType(op.getDst().getType().getElementType());
423-
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
424-
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
417+
auto tmemBase = adaptor.getDst();
425418
Value pred = adaptor.getPred();
426419
lowerStoreToTensorMemory(loc, mod, op.getSrc(), op.getDst(),
427-
adaptor.getSrc(), pred, smemObj, rewriter);
420+
adaptor.getSrc(), pred, tmemBase, rewriter);
428421

429422
rewriter.eraseOp(op);
430423
return success();
@@ -496,11 +489,7 @@ struct TensorMemoryCopyOpConversion
496489
typeConverter->convertType(srcTy.getElementType()), rewriter)
497490
.getBase();
498491

499-
Value baseDst =
500-
LLVM::getSharedMemoryObjectFromStruct(
501-
loc, adaptor.getDst(),
502-
typeConverter->convertType(srcTy.getElementType()), rewriter)
503-
.getBase();
492+
Value baseDst = adaptor.getDst();
504493

505494
// The following codegen assumes that we use tcgen05.cp only with
506495
// the warpx4.32x128b mode, to load blocked scales from MXFP.
@@ -592,8 +581,7 @@ struct MemDescSubviewOpConversion
592581
}
593582

594583
// newBase = base + offset
595-
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
596-
llvmElemTy, rewriter);
584+
auto tmemBase = adaptor.getSrc();
597585
SmallVector<Value> opOffsetVals = op.getOffsets();
598586
size_t destRank = op.getResult().getType().getRank();
599587
SmallVector<Value> offsetVals;
@@ -605,16 +593,13 @@ struct MemDescSubviewOpConversion
605593
triton::nvidia_gpu::TMemAllocation tmemAlloc =
606594
triton::nvidia_gpu::getTmemAllocSizes(cast<MemDescType>(dstTy));
607595
int numColOffset = tmemAlloc.numCols;
608-
Value newBase = b.ptrtoint(rewriter.getI32Type(), smemObj.getBase());
596+
Value newBase = b.ptrtoint(rewriter.getI32Type(), tmemBase);
609597
newBase = rewriter.create<LLVM::AddOp>(
610598
loc, newBase,
611599
rewriter.create<LLVM::MulOp>(loc, opOffsetVals[0],
612600
b.i32_val(numColOffset)));
613601
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
614-
smemObj = SharedMemoryObject(b.inttoptr(elemPtrTy, newBase), llvmElemTy,
615-
offsetVals);
616-
rewriter.replaceOp(op,
617-
getStructFromSharedMemoryObject(loc, smemObj, rewriter));
602+
rewriter.replaceOp(op, b.inttoptr(elemPtrTy, newBase));
618603
return success();
619604
}
620605
};

0 commit comments

Comments (0)