@@ -239,6 +239,42 @@ LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
   return combineCtaCgaWithShape(tileLayout, shared.getCTALayout(), shape);
 }

+LinearLayout warpsDotOperand(MLIRContext *ctx, ArrayRef<unsigned> warpShape,
+                             ArrayRef<unsigned> warpOrder, unsigned inner) {
+  // Let warpsPerCTAMma = {2, 2}; then warpsPerCTA = {2, 1} for opA and
+  // warpsPerCTA = {1, 2} for opB. Assume warpOrder = {1, 0} and that C is
+  // tiled by 2x2 warp tiles. With warpOrder = {1, 0}, C is owned by the
+  // warps as follows:
+  //   C: 0 | 1
+  //      - | -
+  //      2 | 3
+  // To compute C, we need the following warp tilings of A and B:
+  //   A: 0 1 | 0 1    B: 0 2 | 1 3
+  //      - - | - -       - - | - -
+  //      2 3 | 2 3       0 2 | 1 3
+  // In other words, we need to broadcast along K.
+  auto rank = warpShape.size();
+  auto dimNames = standardOutDimNames(ctx, rank);
+  LinearLayout warpLayout = LinearLayout::empty();
+
+  // We have to broadcast along the inner (K) dimension.
+  // For A, moving along M takes us from warp 0 to warp 2.
+  // For B, moving along N takes us from warp 0 to warp 1.
+  // As such, warpOrder = {1, 0} gives A the correct broadcasting; the same
+  // holds when warpOrder is {0, 1}, as on Hopper.
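+  // Worked example (illustrative): with warpShape = {2, 2},
+  // warpOrder = {1, 0}, and inner = 1 (opA, K = dim 1), the loop below emits
+  // a zeros1D base for dim 1 and then an identity1D base for dim 0, so
+  //   warps 0, 1 -> (0, 0) and warps 2, 3 -> (1, 0)
+  // in (dim0, dim1) coordinates, i.e. warps {0, 1} own row 0 of A and warps
+  // {2, 3} own row 1, matching the diagram above.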
+  for (auto d : warpOrder) {
+    if (d == inner) {
+      warpLayout *= LinearLayout::zeros1D(warpShape[d], S("warp"), dimNames[d]);
+    } else {
+      warpLayout *=
+          LinearLayout::identity1D(warpShape[d], S("warp"), dimNames[d]);
+    }
+  }
+  return warpLayout;
+}
+
 } // anonymous namespace

 std::optional<LinearLayout>
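For intuition, a minimal usage sketch of the new helper on the 2x2 example from its comment (illustrative only; assumes an `MLIRContext *ctx` is in scope):

    // opA: K is the innermost dim (inner = rank - 1 = 1), so it is broadcast.
    LinearLayout aWarps =
        warpsDotOperand(ctx, /*warpShape=*/{2, 2}, /*warpOrder=*/{1, 0}, 1);
    // opB: K is dim rank - 2 = 0.
    LinearLayout bWarps =
        warpsDotOperand(ctx, /*warpShape=*/{2, 2}, /*warpOrder=*/{1, 0}, 0);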
@@ -470,7 +506,9 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

   // We use the order from fastest varying to slowest varying. So each base
   // vector is a tuple of values mapping to matrix C's (N, M[, B]) indices.
-  SmallVector<unsigned> order = triton::gpu::getOrder(*this);
+  SmallVector<unsigned> threadOrder = getThreadOrder();
+  assert(threadOrder[0] == mIndex || threadOrder[0] == nIndex);
+  assert(threadOrder[1] == mIndex || threadOrder[1] == nIndex);

   // For wmma with 16x16 output, each of the 32 threads holds 8 elements.
   //
@@ -498,29 +536,106 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
           ? LinearLayout(
                 {{kRegister, {/*gap*/ {0, 2}, {0, 4}, {0, 8}}},
                  {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 1}}}},
-                {outDimNames[order[0]], outDimNames[order[1]]})
+                {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]})
           : LinearLayout(
                 {{kRegister, {{0, 1}, {0, 2}, {0, 4}}},
                  {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}},
-                {outDimNames[order[0]], outDimNames[order[1]]});
+                {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]});

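+  // Worked example (illustrative): in the second variant the out dims are
+  // (fastest, slower); lane bits 0-3 walk the fastest dim (0..15), lane bit 4
+  // jumps 8 in the slower dim, and register bits 0-2 give 8 consecutive
+  // slower-dim elements. E.g. lane 17 (bits 0 and 4 set) with register 3
+  // (bits 0 and 1 set) maps to (1, 8 + 3) = (1, 11).
+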
   if (hasBatchDim) {
-    assert(order[2] == 0);
+    int batchIndex = 0;
     // Extend the base vector with one value to accommodate the batch
     // dimension, which appears last.
-    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
-    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
+    tileLayout *=
+        LinearLayout::identity1D(1, kRegister, outDimNames[batchIndex]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[batchIndex]);
   }

   // And each warp takes the same register and lane sub-layout. So multiply
   // with an identity layout for the warp.
+  auto warpOrder = getWarpOrder();
   LinearLayout warpLayout =
-      identityStandardND(S("warp"), getWarpsPerCTA(), order);
-  LinearLayout ctaLayout = tileLayout * warpLayout;
+      identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder);
+  // Reorder dim names in repetition order so that combineCtaCgaWithShape
+  // generates the proper extension of the layout.
+  auto repOrder = getRepOrder();
+  SmallVector<StringAttr> repDimNames;
+  for (auto dim : repOrder)
+    repDimNames.push_back(outDimNames[dim]);
+  LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) *
+                           warpLayout.transposeOuts(repDimNames);

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

+std::optional<LinearLayout>
+wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
+                             ArrayRef<int64_t> shape) {
+  auto wmmaLayout = llvm::cast<AMDWmmaEncodingAttr>(dotWmmaLayout.getParent());
+  auto rank = shape.size();
+  bool hasBatchDim = rank == 3;
+  auto kDim = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  int32_t kSize = shape[kDim];
+  MLIRContext *ctx = dotWmmaLayout.getContext();
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  // Lane order:
+  //   operand A: [1, 0] / [2, 1, 0]
+  //   operand B: [0, 1] / [1, 2, 0]
+  // In both cases it is [k, nonK] / [k, nonK, batch].
+  SmallVector<unsigned> laneOrder = triton::gpu::getOrder(dotWmmaLayout);
+  // Generate the contiguous part of the register bases (i.e. kWidth).
+  std::vector<std::vector<int32_t>> registerBase;
+  const int32_t kWidth = dotWmmaLayout.getKWidth();
+  for (int i = 1; i < kWidth; i *= 2)
+    registerBase.push_back(std::vector<int32_t>{i, 0});
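+  // For example (illustrative): kWidth = 8 yields
+  // registerBase = {{1, 0}, {2, 0}, {4, 0}}, i.e. three register bits
+  // addressing 8 contiguous elements along k.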
+  std::vector<std::vector<int32_t>> laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}};
+  switch (wmmaLayout.getVersion()) {
+  case 1:
+    // WMMA version 1 duplicates values in lanes 0-15 and 16-31.
+    laneBase.push_back({0, 0});
+    break;
+  case 2:
+    // WMMA version 2 offsets values in lanes 0-15 and 16-31 across the k
+    // dimension.
+    laneBase.push_back({kWidth, 0});
+    break;
+  default:
+    assert(false && "unexpected version");
+  }
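+  // Concretely (illustrative, kWidth = 8): in version 1, lane 16 reads the
+  // same elements as lane 0 (extra base {0, 0}); in version 2, lane 16
+  // starts at k = 8 (extra base {8, 0}).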
+  // Generate the layout for a single wmma instruction.
+  LinearLayout tileLayout(
+      {{kRegister, registerBase}, {kLane, laneBase}},
+      {outDimNames[laneOrder[0]], outDimNames[laneOrder[1]]});
+  if (hasBatchDim) {
+    assert(laneOrder[2] == 0);
+    // Extend the base vector with one value to accommodate the batch
+    // dimension, which appears last.
+    tileLayout *=
+        LinearLayout::identity1D(1, kRegister, outDimNames[laneOrder[2]]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[laneOrder[2]]);
+  }
+
+  // Generate the warp layout.
+  auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
+  auto warpOrder = triton::gpu::getWarpOrder(dotWmmaLayout);
+  LinearLayout warpLayout = warpsDotOperand(ctx, warpsPerCTA, warpOrder, kDim);
+
+  // Reorder dim names in repetition order so that combineCtaCgaWithShape
+  // generates the proper extension of the layout.
+  auto repOrder = wmmaLayout.getRepOrderForOperand(dotWmmaLayout.getOpIdx());
+  SmallVector<StringAttr> repDimNames;
+  for (auto dim : repOrder)
+    repDimNames.push_back(outDimNames[dim]);
+
+  // Join the instruction layout and the warps using the repetition order of
+  // the dimensions.
+  LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) *
+                           warpLayout.transposeOuts(repDimNames);
+
+  return combineCtaCgaWithShape(ctaLayout, wmmaLayout.getCTALayout(), shape);
+}
+
 std::optional<LinearLayout>
 BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   assert(shape.size() == getOrder().size());
@@ -604,44 +719,6 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

-LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
-                            ArrayRef<unsigned> mmaWarpOrder, bool isA) {
-  // Let warpsPerCTAMma = {2, 2}, then
-  // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
-  // assume warpOrder = {1, 0}
-  // Assume that C is tiled by 2x2 tiles. Since warpOrder={1, 0}, we have that
-  // the C is owned as per the following layout:
-  // C: 0 | 1
-  //    - | -
-  //    2 | 3
-  // In order to be able to compute C, we need the following warp tiling of
-  // A and B:
-  // A: 0 1 | 0 1    B: 0 2 | 1 3
-  //    - - | - -       - - | - -
-  //    2 3 | 2 3       0 2 | 1 3
-  // In other words, we need to broadcast along K
-  auto rank = mmaWarpOrder.size();
-  auto inner = isA ? rank - 1 : rank - 2;
-  auto dimNames = standardOutDimNames(ctx, rank);
-  LinearLayout warpLayout = LinearLayout::empty();
-
-  // We have to broadcast along the inner dimension
-  // For A, when moving along M we go from 0 to 2.
-  // For B, when moving along N we go from 0 to 1.
-  // As such, choosing the order of A {1, 0}, gives us the correct broadcasting
-  // Same happens if the mmaWarpOrder is {0, 1}, like in Hopper
-  for (auto d : mmaWarpOrder) {
-    if (d == inner) {
-      warpLayout *=
-          LinearLayout::zeros1D(mmaWarpShape[d], S("warp"), dimNames[d]);
-    } else {
-      warpLayout *=
-          LinearLayout::identity1D(mmaWarpShape[d], S("warp"), dimNames[d]);
-    }
-  }
-  return warpLayout;
-}
-
 LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
                                      DotOperandEncodingAttr dot) {
   int rank = shape.size();
@@ -662,8 +739,9 @@ LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
   }
   auto ctaLayout =
       nvidiaMmaTile(ctx, tileShape, kWidth, getOrder(dot), dot.getRepOrder());
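+  // K is the innermost dim for operand A and the second innermost for B.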
+  auto kDim = isA ? rank - 1 : rank - 2;
   ctaLayout *=
-      warpsNvidiaDot(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), isA)
+      warpsDotOperand(ctx, mma.getWarpsPerCTA(), mma.getWarpOrder(), kDim)
           .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));

   return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
@@ -674,6 +752,8 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   auto parent = getParent();
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
+  } else if (auto wmmaLayout = llvm::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
+    return wmmaDotOperandToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
     return nvidiaDotToLinearLayout(shape, *this);
   }
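For reference, a minimal sketch of how the new dispatch path would be exercised (illustrative only: `dotEnc` is an assumed `DotOperandEncodingAttr` whose parent is an `AMDWmmaEncodingAttr`; the shape and the printing are for demonstration):

    // Sketch: query the linear layout for operand A of a WMMA dot.
    DotOperandEncodingAttr dotEnc = ...; // parent: AMDWmmaEncodingAttr
    std::optional<LinearLayout> ll = dotEnc.toLinearLayout({/*M=*/64, /*K=*/64});
    if (ll)
      llvm::outs() << ll->toString() << "\n";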