@@ -26,6 +26,43 @@ static int __builtin_ctz(unsigned x) {
 
 #endif
 
+// This reverts #5645, because it introduced increased register pressure in
+// the AMD backend.
+// TODO: remove when the new implementation's performance reaches the target level.
+namespace {
+
+LinearLayout getRegToSharedLayout(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                                  LinearLayout regLayout, Attribute dstEnc,
+                                  int elemBitWidth) {
+  StringAttr kBlock = StringAttr::get(ctx, ("block"));
+  int rank = shape.size();
+
+  LinearLayout sharedLayout = triton::gpu::toLinearLayout(shape, dstEnc);
+  auto sharedOrder = triton::gpu::getOrder(dstEnc);
+
+  // sharedLayout's in-dims are currently (offset, block). Reshape to
+  // (offsetX1, offsetX2, ..., block) so that we can apply the N-dimensional
+  // shmem strides. (The offsetX's appear in minor-to-major order.)
+  auto sharedLegacy = cast<triton::gpu::SwizzledSharedEncodingAttr>(dstEnc);
+  SmallVector<std::pair<StringAttr, int32_t>> multiDimSharedSize;
+  for (int i = 0; i < rank; i++) {
+    int dim = sharedOrder[i];
+    int64_t size = std::max(
+        int64_t{1},
+        shape[dim] / sharedLegacy.getCTALayout().getCTASplitNum()[dim]);
+    multiDimSharedSize.push_back(
+        {StringAttr::get(ctx, ("offset" + std::to_string(dim))), size});
+  }
+  multiDimSharedSize.push_back({kBlock, sharedLayout.getInDimSize(kBlock)});
+  sharedLayout = sharedLayout.reshapeIns(multiDimSharedSize);
+
+  // regToSharedLayout maps from (register, lane, warp, block) to (offsetX1,
+  // ..., offsetXN, block), where the offsetX's are in minor-to-major order.
+  return regLayout.invertAndCompose(sharedLayout);
+}
+
+} // namespace
+
 namespace mlir {
 
 namespace triton::gpu {
@@ -251,6 +288,27 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
                                      {kWarp, warpId},
                                      {kBlock, blockId}})[0]
                       .second;
+    // This reverts #5645, because it introduced increased register pressure
+    // in the AMD backend.
+    // TODO: remove when the new implementation's performance reaches the target level.
+    if (auto swizzledSharedEnc =
+            mlir::dyn_cast<triton::gpu::SwizzledSharedEncodingAttr>(
+                sharedEnc)) {
+      auto regToSharedLayout =
+          getRegToSharedLayout(ctx, shape, regLayout, swizzledSharedEnc,
+                               elemLlvmTy.getIntOrFloatBitWidth());
+      auto smemOrder = swizzledSharedEnc.getOrder();
+      smemOffsets = llvm::to_vector(llvm::drop_end(llvm::make_second_range(
+          applyLinearLayout(loc, rewriter, regToSharedLayout,
+                            {{kRegister, regId},
+                             {kLane, laneId},
+                             {kWarp, warpId},
+                             {kBlock, b.i32_val(0)}}))));
+      // Reorder strides according to `order`. This way they match the
+      // multi-dimensional offsets in regToSharedLayout.
+      smemOffset = dot(rewriter, loc, smemOffsets,
+                       applyPermutation(smemStrides, smemOrder));
+    }
   } else { // Case 2 -> rank-reduced swizzling
     assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
     assert(isa<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc) &&