
Commit 1239887

[AMD] Enable subview for amd rotating shared attribute (#6160)
This PR fixes an assert in the converter and adds a relevant lit test.
Parent: fa8b7bb

2 files changed (+23, -1 lines)

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 2 additions & 1 deletion
@@ -353,7 +353,8 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
     }
   } else { // Case 2 -> rank-reduced swizzling
     assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
-    assert(isa<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc) &&
+    assert((isa<triton::gpu::SwizzledSharedEncodingAttr,
+                triton::gpu::AMDRotatingSharedEncodingAttr>(sharedEnc)) &&
            "NVMMA layout not supported for sliced tensors");
     // We define both tensor offsets and shared memory offsets:
     //
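For context, the relaxed assert relies on the variadic form of llvm::isa<>, which returns true when the value is an instance of any of the listed types. Below is a minimal sketch of that check in isolation; the helper name and the exact includes are assumptions made for a standalone example, not code from this commit.

    // Sketch only: illustrates the variadic llvm::isa<> used by the relaxed assert.
    #include "mlir/IR/Attributes.h"                  // mlir::Attribute
    #include "triton/Dialect/TritonGPU/IR/Dialect.h" // shared encoding attributes
    #include "llvm/Support/Casting.h"                // llvm::isa

    // Accept either the generic swizzled shared layout or the AMD rotating
    // shared layout enabled by this commit; any other shared encoding
    // (e.g. NVMMA) still fails the check.
    static bool isSupportedSlicedSharedEncoding(mlir::Attribute sharedEnc) {
      return llvm::isa<triton::gpu::SwizzledSharedEncodingAttr,
                       triton::gpu::AMDRotatingSharedEncodingAttr>(sharedEnc);
    }

Listing both attribute types as template arguments is what lets a memdesc_subview of an #ttg.amd_rotating_shared buffer take the rank-reduced swizzling path instead of tripping the assert.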

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 21 additions & 0 deletions
@@ -359,3 +359,24 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
     tt.return
   }
 }
+
+// -----
+
+// CHECK-LABEL: amd_rotating_subview_shared_layout
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.amd_rotating_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @amd_rotating_subview_shared_layout(%arg0: tensor<64x64xf16, #blocked>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c16_i32 = arith.constant 16 : i32
+    // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3>
+    %0 = ttg.local_alloc %arg0 : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    %1 = ttg.memdesc_subview %0[%c16_i32, %c0_i32] : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 64x64>
+    // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
+    %2 = ttg.local_load %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 64x64> -> tensor<16x64xf16, #blocked>
+    // CHECK-COUNT-4: llvm.store {{.*}} : vector<1xf16>, !llvm.ptr<3>
+    ttg.local_store %2, %1 : tensor<16x64xf16, #blocked> -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable, 64x64>
+    tt.return
+  }
+}
