diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
index deec43f1161c..e49e0b7bed2a 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -45,6 +45,17 @@ class TritonGPUReduceDataDuplicationPass
       if (!cvtNeedsSharedMemory(srcType, dstType))
         return;
       auto order = getOrderForMemory(srcType);
+      auto inputOp = cvtOp.getSrc().getDefiningOp();
+      // If the input of convert_layout is a transpose op, the actual order
+      // is the order of the transpose op's input. Setting the LDS order to
+      // match the input makes ds_write more efficient.
+      if (inputOp) {
+        if (auto transOp = dyn_cast<triton::TransOp>(inputOp)) {
+          order = getOrderForMemory(
+              cast<RankedTensorType>(transOp.getSrc().getType()));
+        }
+      }
+
       auto sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = triton::gpu::MemDescType::get(
diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir
index 4a496459d0cb..0252722f3c25 100644
--- a/test/Conversion/amd/mfma-shortcut.mlir
+++ b/test/Conversion/amd/mfma-shortcut.mlir
@@ -108,3 +108,19 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32}
     tt.return
   }
 }
+
+
+// -----
+
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [8, 1], instrShape = [16, 16], isTransposed = true}>
+#linear = #ttg.linear<{register = [[1, 0], [2, 0], [16, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [4, 0], [8, 0]], warp = [[0, 16], [0, 32], [0, 64]], block = []}>
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // GFX950-LABEL: mfma_dotop_lds_layout_order
+  tt.func public @mfma_dotop_lds_layout_order(%arg0: tensor<128x32xbf16, #mma>) {
+    %1 = tt.trans %arg0 {order = array<i32: 1, 0>} : tensor<128x32xbf16, #mma> -> tensor<32x128xbf16, #linear>
+    // GFX950-COUNT-2: llvm.store
+    // GFX950-COUNT-8: rocdl.ds.read.tr16.b64
+    %2 = ttg.convert_layout %1 : tensor<32x128xbf16, #linear> -> tensor<32x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    tt.return
+  }
+}