
Commit f35f070

Fix the parser of MemDesc (#6117)
I discussed this bug with @ThomasRaoux in #6113: the MemDesc parser ignores the allocShape during parsing. This PR fixes that bug.

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
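For context, the trailing field of the `!ttg.memdesc` textual form is the allocation shape: the shape of the underlying buffer, which may have more (leading) dimensions than the view itself. A minimal example of the affected syntax, using the `#shared0` and `#smem` attributes as defined in the `ops.mlir` test updated below:

```mlir
// A 64x16 view of a 2x64x16 allocation. Before this fix, the parser
// discarded the trailing 2x64x16 and rebuilt the type with the view
// shape (64x16) as its alloc shape.
!ttg.memdesc<64x16xf16, #shared0, #smem, mutable, 2x64x16>
```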
1 parent 318af1a commit f35f070

File tree

4 files changed: +19 additions, -9 deletions

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,10 @@ Type MemDescType::parse(AsmParser &parser) {
   if (parser.parseGreater())
     return Type();

+  if (allocShape.size() > 0)
+    return MemDescType::get(parser.getContext(), dimensions, elementType,
+                            encoding, memorySpace, mutableMemory, allocShape);
+
   return MemDescType::get(parser.getContext(), dimensions, elementType,
                           encoding, memorySpace, mutableMemory, dimensions);
 }
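In short, the old code always fell through to the final `MemDescType::get` call, passing `dimensions` as the alloc shape even when an explicit `allocShape` had been parsed. A sketch of the resulting round-trip loss (attribute definitions omitted; `#shared1` and `#smem` as in the tests below):

```mlir
// Written input:
!ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
// Type the old parser actually constructed: allocShape was reset to the
// logical shape, i.e. equivalent to writing
!ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
```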

test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 6 additions & 6 deletions
@@ -143,7 +143,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 ttg.local_store %14, %17 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
 %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
 ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
-%19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) {
+%19:6 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>) {
 %20 = arith.subi %arg1, %arg2 : index
 %21 = arith.cmpi slt, %arg5, %20 : index
 %22 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
@@ -163,7 +163,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 ttg.local_store %29, %35 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
 %36 = ttg.memdesc_subview %11[%34, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
 ttg.local_store %31, %36 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
-scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
+scf.yield %26, %27, %25, %34, %35, %36 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
 }
 ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
 ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
@@ -242,7 +242,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 ttg.local_store %14, %25 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
 %26 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
 ttg.local_store %16, %26 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
-%27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) {
+%27:8 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %19, %arg7 = %20, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %25, %arg11 = %26, %arg12 = %22, %arg13 = %24) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>) {
 %28 = arith.muli %arg2, %c2 : index
 %29 = arith.subi %arg1, %28 : index
 %30 = arith.cmpi slt, %arg5, %29 : index
@@ -263,7 +263,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 ttg.local_store %arg12, %44 : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>
 %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
 ttg.local_store %arg13, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>
-scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>
+scf.yield %35, %36, %34, %43, %44, %45, %38, %40 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 1x128x32>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable, 1x32x128>, tensor<128x32xf16, #blocked1>, tensor<32x128xf16, #blocked>
 }
 ttg.local_dealloc %10 : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable>
 ttg.local_dealloc %11 : !ttg.memdesc<2x32x128xf16, #shared1, #smem, mutable>
@@ -335,7 +335,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 ttg.local_store %8, %17 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
 %18 = ttg.memdesc_subview %1[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
 ttg.local_store %14, %18 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
-%19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>) {
+%19:7 = scf.for %arg6 = %c0 to %arg1 step %c1 iter_args(%arg7 = %cst, %arg8 = %arg2, %arg9 = %6, %arg10 = %c0_i32, %arg11 = %17, %arg12 = %18, %arg13 = %16) -> (tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>) {
 %20 = arith.subi %arg1, %c2 : index
 %21 = arith.cmpi slt, %arg6, %20 : index
 %22 = arith.subi %arg1, %c1 : index
@@ -362,7 +362,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 ttg.local_store %30, %42 : tensor<16x16xf16, #blocked1> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
 %43 = ttg.memdesc_subview %1[%41, %c0_i32, %c0_i32] : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
 ttg.local_store %36, %43 : tensor<16x16xf16, #blocked> -> !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>
-scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
+scf.yield %26, %27, %28, %41, %42, %43, %38 : tensor<16x16xf32, #mma>, tensor<16x16x!tt.ptr<f16>, #blocked1>, tensor<16x!tt.ptr<i64>, #ttg.slice<{dim = 1, parent = #blocked}>>, i32, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>, !ttg.memdesc<16x16xf16, #shared2, #smem, mutable, 1x16x16>, tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
 }
 ttg.local_dealloc %0 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable>
 ttg.local_dealloc %1 : !ttg.memdesc<1x16x16xf16, #shared2, #smem, mutable>

test/TritonGPU/dot-operands.mlir

Lines changed: 3 additions & 3 deletions
@@ -57,7 +57,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: mmav2_reorder_transpose
 // CHECK: ttg.local_alloc
 // CHECK: ttg.memdesc_trans
@@ -118,8 +118,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 %acc_tm: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
 ) attributes {noinline = false} {
 %true = arith.constant true
-%A_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
-%B_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
+%A_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512>
+%B_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512>
 %A_ll = ttg.local_load %A_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
 %B_ll = ttg.local_load %B_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
 %A_r = tt.reshape %A_ll : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8>

test/TritonGPU/ops.mlir

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w
 tt.func @memdesc(%d : !ttg.memdesc<1x64x16xf16, #shared0, #smem>) {
 tt.return
 }
+
+// CHECK-LABEL: memdesc_with_alloc_shape
+// CHECK-SAME: !ttg.memdesc<64x16xf16, #{{.+}}, mutable, 2x64x16>
+tt.func @memdesc_with_alloc_shape(%d : !ttg.memdesc<64x16xf16, #shared0, #smem, mutable, 2x64x16>){
+tt.return
+}
 }

 // -----
