Commit 7b29378

Remove verification of tmem allocation size (#8782)
This check can fail because `supportMMA` does not validate that the result of a dot fits in tmem before selecting MMAv5, which then leads to an assertion failure. Instead of hardcoding this constraint in the optimizer, remove the check and let the allocator run out of tmem later.
1 parent 06c0b20 commit 7b29378

File tree

2 files changed: +21 −8 lines

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 0 additions & 8 deletions
@@ -143,14 +143,6 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
                  << ll.getOutDimSize(dims[0]) << "x"
                  << ll.getOutDimSize(dims[1]);
     }
-    // Note the following holds for both M=64 and M=128 with 2CTA
-    auto nCol = ll.getInDimSize(StringAttr::get(ctx, "col"));
-    if (nCol / (enc.getCTASplitM() * enc.getCTASplitN()) >
-        512 * 32 / bitwidth) {
-      return emitError() << "nCol / (CTASplitM * CTASplitN) must be less than "
-                            "or equal to 512 * 32 / bitwidth but got "
-                         << nCol / (enc.getCTASplitM() * enc.getCTASplitN());
-    }
   } else if (auto enc = dyn_cast<SharedEncodingTrait>(encoding)) {
     if (memorySpace != SharedMemorySpaceAttr::get(ctx)) {
       return emitError()
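
For reference, the constraint enforced by the deleted lines above reduces to a single integer comparison. The following is a minimal standalone sketch of that arithmetic; fitsInTmem and its plain integer parameters are illustrative only (they are not Triton APIs), and the 512 * 32 figure is taken directly from the removed code.

#include <cstdint>
#include <iostream>

// Sketch of the check the verifier used to perform: the number of tmem
// columns per CTA split must not exceed 512 columns of 32 bits each,
// scaled by the element bitwidth.
bool fitsInTmem(int64_t nCol, int64_t ctaSplitM, int64_t ctaSplitN,
                int64_t bitwidth) {
  return nCol / (ctaSplitM * ctaSplitN) <= 512 * 32 / bitwidth;
}

int main() {
  // With 32-bit elements and no CTA split the limit is 512 columns, so a
  // hypothetical allocation needing 1024 columns would have been rejected.
  std::cout << std::boolalpha << fitsInTmem(1024, 1, 1, 32) << "\n"; // false
}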

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 21 additions & 0 deletions
@@ -743,3 +743,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %d : tensor<128x128xf32, #blocked3>
   }
 }
+
+// -----
+
+// We previously asserted that a tmem allocation must fit in the available tmem.
+// This would cause an assertion failure if the result matrix was too large.
+// Check that we allow the large result in AccelerateMatmul, and leave it to
+// the allocator to fail later.
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: @res_too_big_for_mmav5
+  tt.func public @res_too_big_for_mmav5(%a: tensor<1024x16xf32, #blocked2>, %b: tensor<16x128xf32, #blocked1>, %c: tensor<1024x128xf32, #blocked>) -> tensor<1024x128xf32, #blocked> {
+    %ad = ttg.convert_layout %a : tensor<1024x16xf32, #blocked2> -> tensor<1024x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %bd = ttg.convert_layout %b : tensor<16x128xf32, #blocked1> -> tensor<16x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+    // CHECK: ttng.tc_gen5_mma
+    %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<1024x16xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1024x128xf32, #blocked>
+    tt.return %d : tensor<1024x128xf32, #blocked>
+  }
+}
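
As a rough sanity check on why this result is "too big": assuming the commonly cited Blackwell tensor-memory size of 128 rows by 512 columns of 32-bit cells per SM (a figure not stated in this diff), the test's accumulator alone needs twice that capacity.

#include <cstdint>
#include <iostream>

int main() {
  // Assumed tmem capacity: 128 rows x 512 columns x 4 bytes (~256 KiB).
  constexpr int64_t tmemBytes = 128 * 512 * 4;
  // The test's accumulator: a 1024x128 tensor of f32 (4 bytes per element).
  constexpr int64_t accBytes = 1024 * 128 * 4;
  std::cout << "tmem: " << tmemBytes << " bytes, accumulator: " << accBytes
            << " bytes\n"; // 262144 vs 524288: the result cannot fit.
}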
