
Commit 1d8e147

[BACKEND] Throw an error instead of miscompiling very large tcgen05.mma along N (#8915)
See the comments in the PR for the full description of the issue. The skips in `test_core.py` now skip exactly the tests that would fail, and those tests now fail with a verifier error rather than miscompiling.
1 parent 2c59bed commit 1d8e147

4 files changed, +59 -13 lines changed

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 41 additions & 0 deletions
```diff
@@ -407,6 +407,10 @@ LogicalResult TCGen5MMAOp::verify() {
     return emitOpError("The col stride of the return operand must be 32 / ")
            << retType.getElementTypeBitWidth() << " but got "
            << retEnc.getColStride();
+  // The maximum size of an MMA instruction is 128x256
+  if (retEnc.getBlockN() > 256)
+    return emitOpError("The block size of the return operand must be less than "
+                       "or equal to 256");
 
   auto aSplit = getCTASplitNum(aEnc);
   auto bSplit = getCTASplitNum(bEnc);
@@ -422,6 +426,43 @@ LogicalResult TCGen5MMAOp::verify() {
   if (getTwoCtas()) {
     auto retSplit = getCTASplitNum(retEnc);
 
+    auto nPerCTA = retType.getDimSize(1) / retSplit[1];
+
+    // [Note: numRepN > 1 and two_ctas]
+    // Consider, just as an example, num_ctas=16 and a huge tile of shape
+    // MNK = 512x64x2048.
+    // This is an example of a layout with numRepN=2 and two_ctas=true:
+    // Layout RHS:
+    // #ttg.memdesc<64x2048xf16,
+    //   #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true,
+    //                      elementBitWidth = 16,
+    //                      CGALayout = [[0, 1], [0, 2], [0, 4], [0, 0]]}>>
+    //
+    // As a LinearLayout:
+    // offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1], [8, 2],
+    //           [16, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [32, 0]]
+    // block = [[0, 256], [0, 512], [0, 1024], [0, 0]]
+    //
+    // The issue is that the data from CTA1 should sit next to that of the
+    // first part of the instruction. Since the max instruction size is
+    // 128x256, the layout we should use is
+    // offset = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 1], [8, 2],
+    //           [16, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 256], [32, 0]]
+    // block = [[0, 128], [0, 512], [0, 1024], [0, 0]]
+    // (note how we swapped the bases [0, 256] and [0, 128]).
+    // The issue with that layout is that it breaks the invariant that the
+    // CGALayout splits the CGA tile into contiguous CTA tiles,
+    // i.e. total_layout = cta_layout * cga_layout.
+    // This invariant is used all over the place, to the point that for all
+    // legacy layouts we represent the CGALayout as the `cga_layout` we have
+    // to multiply on the right.
+    // With a bit of effort we could allow SharedLinearLayouts that do not
+    // divide on the right by a CGALayout, but for now we throw a lovely error.
+    if (nPerCTA > 256)
+      return emitOpError(
+          "We don't allow emitting more than one MMA instruction along N. "
+          "Reduce the block or increase the number of warps or CTAs along N");
+
     unsigned retM = retSplit[0];
     unsigned retN = retSplit[1];
     if (aSplit[0] != retM) {
```
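For intuition, here is a small, self-contained Python sketch of the contiguity argument in the note above. It is not Triton's LinearLayout API; the `span`/`contiguous` helpers and the reduction to N-components only are assumptions made for illustration. It XOR-combines the N components of the offset bases and checks whether the columns owned by one CTA form a single contiguous range.

```python
# Illustrative stand-in for LinearLayout, not Triton's API.
def span(offset_bases_n, block_off):
    # All N-columns one CTA touches: every XOR-combination of the offset
    # bases' N components, shifted by the CTA's block offset.
    cols = {0}
    for b in offset_bases_n:
        cols |= {c ^ b for c in cols}
    return sorted(c ^ block_off for c in cols)

def contiguous(cols):
    return cols == list(range(cols[0], cols[0] + len(cols)))

# N components of the bases from the note (dim0 components omitted).
good = [1, 2, 4, 8, 16, 32, 64, 128]  # block N-bases: 256, 512, 1024
bad  = [1, 2, 4, 8, 16, 32, 64, 256]  # block N-bases: 128, 512, 1024 (swapped)

print(contiguous(span(good, 256)))  # True:  the CTA owns columns [256, 512)
print(contiguous(span(bad, 128)))   # False: [128, 256) plus [384, 512)
```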

python/test/gluon/test_core.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -583,8 +583,9 @@ def min_shape(swizzling, dim0, dim1, trans):
     if M * N // 128 // num_ctas > MAX_ROWS:
         N //= (M * N // 128 // num_ctas // MAX_ROWS)
 
-    if two_ctas and warps != [8, 1] and (shape_m, shape_n, shape_k) == (2, 4, 1):
-        pytest.skip("FIXME: Fails with wrong answer. Need to investigate")
+    if two_ctas and N // ctas_per_cga[1] == 512:
+        # grep for [Note: numRepN > 1 and two_ctas]
+        pytest.skip("grep for [Note: numRepN > 1 and two_ctas]")
 
     assert M >= 64, "M must be at least 64 for mmav3 and mmav5"
 
@@ -671,7 +672,7 @@ def make_2cta_cga_layout(ctas_per_cga, cta_split, cta_order, two_cta_dim):
     shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2,
                                              transposed=transpose_b, cga_layout=cga_layout_b)
     if use_tcgen05:
-        tmem_shape = (min(M // ctas_per_cga[0], 128), N // ctas_per_cga[1])
+        tmem_shape = (min(M // ctas_per_cga[0], 128), min(N // ctas_per_cga[1], 256))
         acc_layout = TensorMemoryLayout(tmem_shape, col_stride=32 // torch.finfo(acc_dtype).bits,
                                         cta_split_num=tuple(ctas_per_cga), two_ctas=two_ctas)
     else:
```
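A rough sketch of the arithmetic behind the new skip condition (the concrete values are assumed for illustration and mirror the verifier check, not the test's actual parametrization):

```python
# Assumed example values, in the spirit of the note's 512x64x2048 scenario.
N, ctas_per_cga = 2048, (1, 4)
n_per_cta = N // ctas_per_cga[1]   # 512: exactly what the skip condition tests
num_rep_n = -(-n_per_cta // 256)   # ceil(512 / 256) == 2 MMA tiles along N
assert num_rep_n > 1               # with two_ctas, the verifier now rejects this
```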

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h

Lines changed: 5 additions & 2 deletions
```diff
@@ -157,8 +157,11 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader {
     assert(to_vector(ll.getOutDimNames()) ==
            llvm::to_vector(
                ArrayRef<StringAttr>{str_attr("offset"), str_attr("block")}));
-    int32_t totalOffElems = ll.apply({{dims[0], a}, {dims[1], b}})[0].second;
-    int32_t smemByteOffsetb8 = totalOffElems * desc.bitwidth / 8;
+    auto offsetBlock = ll.apply({{dims[0], a}, {dims[1], b}});
+    int32_t offsetElems = offsetBlock[0].second;
+    int32_t block = offsetBlock[1].second;
+    assert(block == 0);
+    int32_t smemByteOffsetb8 = offsetElems * desc.bitwidth / 8;
     auto currDesc = desc.descriptor;
     // Take the next 0/1/2/3 bits after the 128b tile
     uint32_t mask = (desc.swizzlingByteWidth >> 4) - 1;
```
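The change unpacks both outputs of `ll.apply` and asserts the `block` component is zero, since this loader computes an offset into the local CTA's shared memory only. A hypothetical Python analogue of the pattern (the names mirror the C++; the toy layout is an assumption):

```python
def smem_byte_offset(ll_apply, a, b, bitwidth):
    # ll_apply returns [("offset", elems), ("block", blk)], like the C++ code.
    (_, offset_elems), (_, block) = ll_apply(a, b)
    assert block == 0, "this loader only addresses the local CTA's smem"
    return offset_elems * bitwidth // 8

# Toy row-major 16-column layout that never crosses CTAs.
toy = lambda a, b: [("offset", a * 16 + b), ("block", 0)]
print(smem_byte_offset(toy, 3, 5, 16))  # (3*16 + 5) elems * 2 bytes = 106
```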

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 9 additions & 8 deletions
```diff
@@ -415,10 +415,15 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
   auto tensorMemAttr =
       cast<ttng::TensorMemoryEncodingAttr>(dTensorTy.getEncoding());
   unsigned mmaSizeM = tensorMemAttr.getBlockM();
-  unsigned mmaSizeN = std::min(tensorMemAttr.getBlockN(), 256u);
+  unsigned mmaSizeN = tensorMemAttr.getBlockN();
+  // Checked in the verifier
+  assert(mmaSizeN <= 256 &&
+         "The maximum size of an MMA instruction is 128x256");
   unsigned mmaSizeK = op.mmaSizeK;
   int numRepM = ceil<unsigned>(M, mmaSizeM);
   int numRepN = ceil<unsigned>(N, mmaSizeN);
+  assert((!twoCTAs || numRepN == 1) &&
+         "grep for [Note: numRepN > 1 and two_ctas]");
   int numRepK = ceil<unsigned>(K, mmaSizeK);
 
   SmallVector<int64_t> shapeA = op.shapeA;
@@ -439,13 +444,9 @@ void convertDotImpl(const LLVMTypeConverter &typeConverter,
 
   auto isFp4b = op.numBitsPerElementB == 4;
   // [Instr shape twoCTAs]
-  // This division by 2 in 2CTA mode a bit subtle:
-  // The issue here is that in 2CTA you multiply in one instruction a tensor
-  // of shape MNK = 256, K, N, and you put it into TMEM of shape 128, K, N*2.
-  // So to compute the shapePerCTA, on the lhs we can look at the TMEM shape,
-  // but to compute the shapePerCTA on the rhs, we need to divide by 2.
-  // Something similar happens when we multiply by 2 the mmaSizeM when creating
-  // It's a massive code smell tho
+  // To compute an output tile of [mmaSizeM, mmaSizeN] in 2CTA mode, we load
+  // an A of size 2 * mmaSizeM x K and a B of size K x mmaSizeN.
+  // Now, since B is split among 2 CTAs along N, we need to divide by 2.
   DotOpMmaSmemLoader bLoader = DotOpMmaSmemLoader::build(
       loc, rewriter, bTensorTy, baseB, {mmaSizeK, mmaSizeN / (twoCTAs ? 2 : 1)},
       1, 5, isFp4b);
```
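A sketch of the 2-CTA shape bookkeeping the rewritten comment describes (illustrative names, not the actual builder API): each CTA of the pair holds half of B along N, so the shared-memory tile handed to the B loader is halved along N when `twoCTAs` is set.

```python
def b_loader_tile(mma_size_k, mma_size_n, two_ctas):
    # Per-CTA smem tile for B: halved along N in 2CTA mode, since the
    # K x mmaSizeN operand is split between the two CTAs of the pair.
    return (mma_size_k, mma_size_n // (2 if two_ctas else 1))

print(b_loader_tile(64, 256, True))   # (64, 128): each CTA holds half of B
print(b_loader_tile(64, 256, False))  # (64, 256): one CTA holds all of B
```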
