@@ -390,6 +390,135 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
 
+LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
+                                          ArrayRef<int64_t> shape,
+                                          int32_t elemBitWidth) {
+  auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
+  assert(mfmaLayout.getMDim() == 16 || mfmaLayout.getMDim() == 32);
+  assert(elemBitWidth == 16);
+
+  auto rank = shape.size();
+  bool hasBatchDim = rank == 3;
+  int32_t kWidthDot = dotMfmaLayout.getKWidth();
+  // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit
+  // loads for most element sizes (16b, 8b, 4b).
+  const int32_t ldsReadWidth = 64;
+  int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
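+  // Illustrative arithmetic for the 16-bit case this function asserts: one
+  // 64-bit transposed read covers 64 / 16 = 4 elements, so each thread loads
+  // 4 contiguous 16-bit values at a time.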
+  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+
+  int32_t kSize = shape[kDim];
+  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
+
+  MLIRContext *ctx = dotMfmaLayout.getContext();
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  // Register order:
+  //   operand A: [1, 0] / [2, 1, 0]
+  //   operand B: [0, 1] / [1, 2, 0]
+  // The regular dot mfma order for both cases is [k, nonk] / [k, nonk, batch].
+  // For the LDS transpose layout, swap the order to [nonk, k] / [nonk, k, batch].
+  SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
+  std::swap(order[0], order[1]);
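+  // For example, for operand A with rank == 2, getOrder() returns [1, 0]
+  // (k-dim fastest); after the swap, order == [0, 1] (non-k-dim fastest).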
+
+  // In the LDS transpose logic, each thread accesses 64 bits (8 bytes) of
+  // data. The smallest unit for transposing is a 4x4 sub-tile of threads,
+  // where each thread reads 4 16-bit elements along the non-K dimension,
+  // resulting in a [non-K, K] = {16, 4} sub-tile of elements. Because of the
+  // transposing mechanism, each thread ends up with 4 16-bit elements along
+  // the K dim.
+  //
+  // The MFMA selection logic prioritizes double-rate MFMA instructions
+  // whenever possible. Specifically:
+  // - For MFMA operations with non-K = 16: when blockK > 16, mfma16x16x32 is
+  //   selected; otherwise (blockK <= 16), mfma16x16x16 remains the choice.
+  // - For MFMA operations with non-K = 32: when blockK > 8, mfma32x32x16 is
+  //   selected; otherwise (blockK <= 8), mfma32x32x8 is used.
+  //
+  // In double-rate MFMA instructions, each thread holds 8 elements along the
+  // K dimension:
+  // - The first 4 elements belong to the first sub-tile.
+  // - The next 4 elements belong to the second sub-tile.
+  //
+  // We then group these into larger tiles, each consisting of 8 of these 16x4
+  // sub-tiles. These tiles correspond to the data for one mfma instruction.
+  // The shapes of these tiles depend on the MFMA instruction used:
+  // 1. For mfma32x32x16, the tile shape is [non-K, K] = {32, 16}.
+  // 2. For mfma16x16x32, the tile shape is [non-K, K] = {16, 32}.
+  //
+  // For single-rate mfma instructions, each thread holds 4 elements along the
+  // K dimension. This means the larger tile (corresponding to one mfma
+  // instruction) consists of 4 16x4 sub-tiles.
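+  //
+  // Worked example of the bases below, in [non-K, K] coordinates (derived
+  // from the first-sub-tile description above, 16-bit elements): register
+  // bits {1, 0} and {2, 0} address a thread's 4 consecutive non-K elements,
+  // lane bits {4, 0} and {8, 0} tile groups of 4 threads along non-K, and
+  // lane bits {0, 1} and {0, 2} step along K. So lane 5 (bits 0 and 2 set),
+  // register 2 maps to {4, 0} + {0, 1} + {2, 0} = [non-K = 6, K = 1].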
+  std::vector<std::vector<int32_t>> registerBase = {{1, 0},
+                                                    {2, 0}}; // first sub-tile
+  std::vector<std::vector<int32_t>> laneBase = {{kWidthTransRead, 0},
+                                                {2 * kWidthTransRead, 0},
+                                                {0, 1},
+                                                {0, 2}}; // first sub-tile
+
+  // Extend the register base for multiple tiles in the K dimension
+  // (corresponding to multiple mfma instructions across the K dim).
+  auto populateRegisterBase = [&](int kTileSize) {
+    const int regsPerTile = 8;
+    int numRegs = (kSize / kTileSize) * regsPerTile;
+    for (int reg = regsPerTile; reg < numRegs; reg *= 2) {
+      registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
+    }
+  };
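+  // Illustrative expansion: with kSize = 64 and kTileSize = 16, numRegs =
+  // (64 / 16) * 8 = 32, so the loop appends {0, 16} (reg = 8) and {0, 32}
+  // (reg = 16), i.e. one extra register bit per additional mfma tile along K.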
+
+  const bool isMfma32 = (mfmaLayout.getMDim() == 32);
+  const bool isMfma16 = (mfmaLayout.getMDim() == 16);
+  const int kTileSize = isMfma32 ? 16 : 32;
+
+  if (kSize >= kTileSize) {
+    // Handles the mfma32x32x16 and mfma16x16x32 cases.
+    assert(kWidthDot == 8);
+    registerBase.push_back({0, 4}); // second sub-tile
+    populateRegisterBase(kTileSize);
+    auto laneBaseExt =
+        isMfma32 ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 8}}
+                 : std::vector<std::vector<int32_t>>{{0, 8}, {0, 16}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  } else {
+    // Handles the mfma32x32x8 and mfma16x16x16 cases.
+    assert(kWidthDot == 4);
+    auto laneBaseExt =
+        isMfma32 ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 4}}
+                 : std::vector<std::vector<int32_t>>{{0, 4}, {0, 8}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  }
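+  // Sanity check for the double-rate mfma16x16x32 path: laneBase becomes
+  // {{4, 0}, {8, 0}, {0, 1}, {0, 2}, {0, 8}, {0, 16}}, i.e. 6 lane bits
+  // matching AMD's 64-lane wavefront; together with the register bits this
+  // spans the full [non-K, K] = {16, 32} tile.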
+
+  // The base vectors above are defined in a fixed [non-k-dim, k-dim] order.
+  // The `order` array assigns them to the actual matrix dimensions:
+  //   For operand A: non-k-dim -> dim0, k-dim -> dim1
+  //   For operand B: non-k-dim -> dim1, k-dim -> dim0
+  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
+                          {outDimNames[order[0]], outDimNames[order[1]]});
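+  // For instance, for operand A (opIdx == 0) with rank == 2, order == [0, 1]
+  // after the earlier swap, so the output dims are (dim0 = non-K, dim1 = K),
+  // matching an [M, K] operand.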
+
+  if (hasBatchDim) {
+    assert(order[2] == 0);
+    // Extend the base vector with one value to accommodate the batch
+    // dimension, which appears last.
+    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
+  }
+
+  // Warp order:
+  //   common for both operand A and B: [0, 1] / [0, 1, 2]
+  //   in both cases it is [M dim, N dim] / [batch, M dim, N dim]
+  SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
+  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
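+  // E.g. warpsPerCTA = {2, 2} with warpOrder = [0, 1] places warp bit 0
+  // along dim0 and warp bit 1 along dim1, forming a 2x2 grid of warps over
+  // the M and N dims (a sketch of identityStandardND's behavior).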
+
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
+  auto finalLayout =
+      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+
+  return finalLayout;
+}
+
 LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
                                    ArrayRef<int64_t> shape) {
 
@@ -1200,4 +1329,10 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
 }
 
+LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
+                                       int32_t elemBitWidth) {
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotDsReadB64Tr16Layout(dot, shape, elemBitWidth);
+}
+
 } // namespace mlir::triton::gpu