@@ -59,18 +59,18 @@ Type replaceLayout(const Type &type, const Attribute &newLayout) {
5959
6060ttg::DistributedEncodingTrait
6161replaceCTALayout (ttg::DistributedEncodingTrait layout,
62- llvm::ArrayRef<int64_t > shape,
62+ llvm::ArrayRef<int64_t > shape, int numWarps,
6363 const ttg::CTALayoutAttr &newCTALayout) {
6464 if (auto blockedLayout = mlir::dyn_cast<ttg::BlockedEncodingAttr>(layout)) {
6565 return ttg::BlockedEncodingAttr::get (
6666 layout.getContext (), shape, blockedLayout.getSizePerThread (),
67- blockedLayout.getOrder (), ttg::getNumWarpsPerCTA (layout), 32 ,
68- newCTALayout);
67+ blockedLayout.getOrder (), numWarps, 32 , newCTALayout);
6968 } else if (auto sliceLayout =
7069 mlir::dyn_cast<ttg::SliceEncodingAttr>(layout)) {
7170 return ttg::SliceEncodingAttr::get (
7271 layout.getContext (), sliceLayout.getDim (),
73- replaceCTALayout (sliceLayout.getParent (), shape, newCTALayout));
72+ replaceCTALayout (sliceLayout.getParent (), shape, numWarps,
73+ newCTALayout));
7474 } else {
7575 // Other layouts are generated by passes after PlanCTAPass
7676 llvm::report_fatal_error (" replaceCTALayout not implemented" );
@@ -293,11 +293,15 @@ bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
293293 // FIXME: Should consider IR with more than one DotOps
294294 setTiling ({splitM, splitN, 1 });
295295
296+ OpBuilder builder (dot);
297+ auto numThreads = ttg::lookupThreadsPerWarp (builder);
298+ auto numWarps = ttg::lookupNumWarps (dot);
299+
296300 auto newCTALayout = ttg::CTALayoutAttr::get (ctx, {splitM, splitN},
297301 {splitM, splitN}, {1 , 0 });
298302 auto newDLayout = ttg::BlockedEncodingAttr::get (
299303 ctx, dTy.getShape (), dLayout.getSizePerThread (), dLayout.getOrder (),
300- ttg::getNumWarpsPerCTA (dLayout), 32 , newCTALayout);
304+ numWarps, numThreads , newCTALayout);
301305 auto newALayout = ttg::DotOperandEncodingAttr::get (ctx, aLayout.getOpIdx (),
302306 newDLayout, 0 );
303307 auto newBLayout = ttg::DotOperandEncodingAttr::get (ctx, bLayout.getOpIdx (),
@@ -359,12 +363,14 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
359363 if (remainingCTAs > 0 )
360364 CTAsPerCGA[order[rank - 1 ]] *= remainingCTAs;
361365
366+ auto numWarps = ttg::lookupNumWarps (reduce);
362367 auto CTALayout =
363368 ttg::CTALayoutAttr::get (context, CTAsPerCGA, CTASplitNum, CTAOrder);
364369 if (!tiled)
365370 setTiling (CTALayout.getCTAsPerCGA ());
366- auto newSrcLayout = replaceCTALayout (
367- cast<ttg::DistributedEncodingTrait>(srcLayout), srcShape, CTALayout);
371+ auto newSrcLayout =
372+ replaceCTALayout (cast<ttg::DistributedEncodingTrait>(srcLayout),
373+ srcShape, numWarps, CTALayout);
368374 auto newResultLayout =
369375 ttg::SliceEncodingAttr::get (context, axis, newSrcLayout);
370376 unsigned numOperands = reduce.getNumOperands ();
@@ -386,6 +392,7 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
386392 stores.push_back (op);
387393 });
388394 assert (stores.size () > 0 && " Cannot find store-like ops" );
395+ auto numWarps = ttg::lookupNumWarps (funcOp);
389396
390397 ttg::CTALayoutAttr CTALayout;
391398 for (Operation *store : stores) {
@@ -398,7 +405,7 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
398405 }
399406 auto newLayout = replaceCTALayout (
400407 cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding ()),
401- tensorTy.getShape (), CTALayout);
408+ tensorTy.getShape (), numWarps, CTALayout);
402409 processElementwise (store, newLayout);
403410 }
404411 }
@@ -624,6 +631,7 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
624631 }
625632
626633 auto CTALayout = ttg::getCTALayout (layout);
634+ auto numWarps = ttg::lookupNumWarps (op);
627635
628636 llvm::SmallVector<Attribute> newOperandLayouts;
629637 for (unsigned i = 0 ; i < op->getNumOperands (); ++i) {
@@ -634,7 +642,7 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
634642 auto oldLayout =
635643 cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding ());
636644 auto newLayout =
637- replaceCTALayout (oldLayout, tensorTy.getShape (), CTALayout);
645+ replaceCTALayout (oldLayout, tensorTy.getShape (), numWarps, CTALayout);
638646 newOperandLayouts.push_back (newLayout);
639647 }
640648
@@ -647,7 +655,7 @@ bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
647655 auto oldLayout =
648656 cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding ());
649657 auto newLayout =
650- replaceCTALayout (oldLayout, tensorTy.getShape (), CTALayout);
658+ replaceCTALayout (oldLayout, tensorTy.getShape (), numWarps, CTALayout);
651659 newResultLayouts.push_back (newLayout);
652660 }
653661
0 commit comments