@@ -222,7 +222,7 @@ static void reorderScales(SmallVector<Value> &srcValues, int64_t k) {
222222
223223static void lowerStoreToTensorMemory (Location loc, ModuleOp mod, Value src,
224224 Value dest, Value llSrc, Value pred,
225- SharedMemoryObject smemObj ,
225+ Value tmemBase ,
226226 ConversionPatternRewriter &rewriter) {
227227 auto b = TritonLLVMOpBuilder (loc, rewriter);
228228 SmallVector<Value> srcValues = unpackLLElements (loc, llSrc, rewriter);
@@ -236,7 +236,7 @@ static void lowerStoreToTensorMemory(Location loc, ModuleOp mod, Value src,
236236 }
237237 int regIdx = 0 ;
238238 calculateAddressAndEmitTmemMessage (
239- loc, mod, smemObj. getBase () , cast<RankedTensorType>(src.getType ()),
239+ loc, mod, tmemBase , cast<RankedTensorType>(src.getType ()),
240240 cast<MemDescType>(dest.getType ()), rewriter,
241241 [&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
242242 int regsPerMessage, bool useStridedMessage) {
@@ -283,17 +283,13 @@ struct TensorMemoryAllocOpConversion
283283 std::iota (order.begin (), order.end (), 0 );
284284 std::reverse (order.begin (), order.end ());
285285 auto shape = op.getType ().getShape ();
286- auto smemObj = SharedMemoryObject (ptr, op.getType ().getElementType (),
287- shape.size (), loc, rewriter);
288286
289287 if (op.getSrc ()) {
290288 lowerStoreToTensorMemory (loc, mod, op.getSrc (), op.getResult (),
291- adaptor.getSrc (), b.i1_val (true ), smemObj,
292- rewriter);
289+ adaptor.getSrc (), b.i1_val (true ), ptr, rewriter);
293290 }
294291
295- auto retVal = getStructFromSharedMemoryObject (loc, smemObj, rewriter);
296- rewriter.replaceOp (op, retVal);
292+ rewriter.replaceOp (op, ptr);
297293 return success ();
298294 }
299295};
@@ -381,13 +377,11 @@ struct TensorMemoryLoadOpConversion
381377 auto mod = op->getParentOfType <ModuleOp>();
382378 auto llvmElemTy =
383379 getTypeConverter ()->convertType (op.getSrc ().getType ().getElementType ());
384- auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
385- op.getLoc (), adaptor.getSrc (), llvmElemTy, rewriter);
380+ auto tmemBase = adaptor.getSrc ();
386381
387382 SmallVector<Value> resultVals;
388383 calculateAddressAndEmitTmemMessage (
389- loc, mod, smemObj.getBase (), op.getType (), op.getSrc ().getType (),
390- rewriter,
384+ loc, mod, tmemBase, op.getType (), op.getSrc ().getType (), rewriter,
391385 [&](Value startAddress, int secondHalfColOffset, bool unpackedb16,
392386 int regsPerMessage, bool useStridedMessage) {
393387 Value packedValues = createTensorMemoryLoad (
@@ -420,11 +414,10 @@ struct TensorMemoryStoreOpConversion
420414 auto mod = op->getParentOfType <ModuleOp>();
421415 auto llvmElemTy =
422416 getTypeConverter ()->convertType (op.getDst ().getType ().getElementType ());
423- auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
424- op.getLoc (), adaptor.getDst (), llvmElemTy, rewriter);
417+ auto tmemBase = adaptor.getDst ();
425418 Value pred = adaptor.getPred ();
426419 lowerStoreToTensorMemory (loc, mod, op.getSrc (), op.getDst (),
427- adaptor.getSrc (), pred, smemObj , rewriter);
420+ adaptor.getSrc (), pred, tmemBase , rewriter);
428421
429422 rewriter.eraseOp (op);
430423 return success ();
@@ -496,11 +489,7 @@ struct TensorMemoryCopyOpConversion
496489 typeConverter->convertType (srcTy.getElementType ()), rewriter)
497490 .getBase ();
498491
499- Value baseDst =
500- LLVM::getSharedMemoryObjectFromStruct (
501- loc, adaptor.getDst (),
502- typeConverter->convertType (srcTy.getElementType ()), rewriter)
503- .getBase ();
492+ Value baseDst = adaptor.getDst ();
504493
505494 // The following codegen assumes that we use tcgen05.cp only with
506495 // the warpx4.32x128b mode, to load blocked scales from MXFP.
@@ -592,8 +581,7 @@ struct MemDescSubviewOpConversion
592581 }
593582
594583 // newBase = base + offset
595- auto smemObj = LLVM::getSharedMemoryObjectFromStruct (loc, adaptor.getSrc (),
596- llvmElemTy, rewriter);
584+ auto tmemBase = adaptor.getSrc ();
597585 SmallVector<Value> opOffsetVals = op.getOffsets ();
598586 size_t destRank = op.getResult ().getType ().getRank ();
599587 SmallVector<Value> offsetVals;
@@ -605,16 +593,13 @@ struct MemDescSubviewOpConversion
605593 triton::nvidia_gpu::TMemAllocation tmemAlloc =
606594 triton::nvidia_gpu::getTmemAllocSizes (cast<MemDescType>(dstTy));
607595 int numColOffset = tmemAlloc.numCols ;
608- Value newBase = b.ptrtoint (rewriter.getI32Type (), smemObj. getBase () );
596+ Value newBase = b.ptrtoint (rewriter.getI32Type (), tmemBase );
609597 newBase = rewriter.create <LLVM::AddOp>(
610598 loc, newBase,
611599 rewriter.create <LLVM::MulOp>(loc, opOffsetVals[0 ],
612600 b.i32_val (numColOffset)));
613601 auto elemPtrTy = ptr_ty (rewriter.getContext (), 3 );
614- smemObj = SharedMemoryObject (b.inttoptr (elemPtrTy, newBase), llvmElemTy,
615- offsetVals);
616- rewriter.replaceOp (op,
617- getStructFromSharedMemoryObject (loc, smemObj, rewriter));
602+ rewriter.replaceOp (op, b.inttoptr (elemPtrTy, newBase));
618603 return success ();
619604 }
620605};
0 commit comments