Commit 25355fc

[TMA] Fix assignment of unswizzled layouts from rank-reducing loads (#6362)
The code assumed the layout needed to be rank-reduced, but since the encoding is propagated backward from the load, its rank actually needs to be increased to match the descriptor rank.
1 parent 0315d72 commit 25355fc
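
For illustration, below is a minimal standalone C++ sketch of what the new order-extension loop computes (extendOrderToRank is a hypothetical helper name used only here, not part of the Triton codebase): the dimensions introduced by the higher-rank descriptor are pushed onto the front of the order vector, so a 2D order of [1, 0] propagated to a rank-3 descriptor becomes [2, 1, 0], matching the #[[SWIZZLE_3D]] layout the new test checks for.

// Minimal sketch, assuming only standard C++; extendOrderToRank is a
// made-up stand-in for the loop added in TMAUtilities.h below.
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<unsigned> extendOrderToRank(std::vector<unsigned> oldOrder,
                                        unsigned rank) {
  assert(oldOrder.size() <= rank);
  std::vector<unsigned> order;
  // Prepend the dimensions introduced by the higher-rank descriptor.
  for (unsigned i = 0; i + oldOrder.size() < rank; ++i)
    order.push_back(rank - i - 1);
  // Keep the original order for the remaining dimensions.
  order.insert(order.end(), oldOrder.begin(), oldOrder.end());
  return order;
}

int main() {
  // A 2D swizzled layout with order [1, 0] propagated to a rank-3 descriptor.
  for (unsigned d : extendOrderToRank({1, 0}, 3))
    std::printf("%u ", d); // prints: 2 1 0
  std::printf("\n");
  return 0;
}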

File tree

2 files changed: +29 -2 lines changed


include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 6 additions & 2 deletions
@@ -81,8 +81,12 @@ updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding,
       return swizEnc;
 
     auto rank = tensorType.getRank();
-    SmallVector<unsigned> order(
-        swizEnc.getOrder().drop_front(swizEnc.getOrder().size() - rank));
+    auto oldOrder = swizEnc.getOrder();
+    assert(oldOrder.size() <= rank);
+    SmallVector<unsigned> order;
+    for (int i = 0; i + oldOrder.size() < rank; ++i)
+      order.push_back(rank - i - 1);
+    order.append(oldOrder.begin(), oldOrder.end());
     auto newCtaEnc = updateCTALayoutForShape(ctaLayout, tensorType.getShape());
     return gpu::SwizzledSharedEncodingAttr::get(
         ctx, swizEnc.getVec(), swizEnc.getPerPhase(), swizEnc.getMaxPhase(),

test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir

Lines changed: 23 additions & 0 deletions
@@ -40,3 +40,26 @@ tt.func public @tma_scatter(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %ar
   tt.return
 }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
+// CHECK-DAG: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK-DAG: #[[SWIZZLE_3D:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [2, 1, 0]}>
+// CHECK-DAG: #[[SWIZZLE_2D:.*]] = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+tt.func public @tma_scatter(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: i32, %arg3: i64, %arg4: i64) {
+  // CHECK: tt.make_tensor_descriptor {{.*}} : <f32>, <tensor<1x256x32xf32, #[[SWIZZLE_3D]]>>
+  // CHECK: %[[LOAD:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<tensor<1x256x32xf32, #[[SWIZZLE_3D]]>> -> tensor<256x32xf32, #[[BLOCKED]]>
+  // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<256x32xf32, #[[BLOCKED]]>) -> !ttg.memdesc<256x32xf32, #[[SWIZZLE_2D]], #smem>
+  %c1_i32 = arith.constant 1 : i32
+  %c1_i64 = arith.constant 1 : i64
+  %0 = tt.make_tensor_descriptor %arg0, [%c1_i32, %arg1, %arg2], [%arg3, %arg4, %c1_i64] : <f32>, <tensor<1x256x32xf32>>
+  %1 = tt.descriptor_load %0[%c1_i32, %c1_i32, %c1_i32] : !tt.tensordesc<tensor<1x256x32xf32>> -> tensor<256x32xf32, #blocked>
+  %2 = ttg.local_alloc %1 : (tensor<256x32xf32, #blocked>) -> !ttg.memdesc<256x32xf32, #shared, #smem>
+  tt.return
+}
+}
