From 68d4ca28c608efbdc11f1d7374564e8ad098619e Mon Sep 17 00:00:00 2001 From: the-strawhat Date: Wed, 15 Oct 2025 22:07:45 +0800 Subject: [PATCH] [AMD] Implement implicit layout conversion for DotOp to enable direct GMEM to reg loads --- bin/RegisterTritonDialects.h | 1 + .../amd/amd-implicit-convert-layout.mlir | 237 +++++++++++++++ third_party/amd/backend/compiler.py | 2 + .../include/TritonAMDGPUTransforms/Passes.td | 10 + .../lib/TritonAMDGPUTransforms/CMakeLists.txt | 1 + .../ImplicitConvertLayout.cpp | 274 ++++++++++++++++++ third_party/amd/python/triton_amd.cc | 1 + 7 files changed, 526 insertions(+) create mode 100644 test/TritonGPU/amd/amd-implicit-convert-layout.mlir create mode 100644 third_party/amd/lib/TritonAMDGPUTransforms/ImplicitConvertLayout.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 48049f02682f..1db1bb489031 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -103,6 +103,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::arith::registerConvertArithToLLVMInterface(registry); // TritonAMDGPUTransforms passes + mlir::registerTritonAMDGPUImplicitConvertLayout(); mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUHoistLayoutConversions(); diff --git a/test/TritonGPU/amd/amd-implicit-convert-layout.mlir b/test/TritonGPU/amd/amd-implicit-convert-layout.mlir new file mode 100644 index 000000000000..c4539dc332e0 --- /dev/null +++ b/test/TritonGPU/amd/amd-implicit-convert-layout.mlir @@ -0,0 +1,237 @@ +// RUN: triton-opt %s -split-input-file --tritonamdgpu-implicit-convert-layout --tritongpu-remove-layout-conversions | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1, 1, 8], threadsPerWarp = [1, 4, 16, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 8], threadsPerWarp = [1, 32, 2], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [4, 0], [8, 0]], lane = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 4], [0, 8]], warp = [[1, 0], [2, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 0, 0, 1], [0, 0, 0, 2], [0, 2, 0, 0], [0, 4, 0, 0], [0, 8, 0, 0], [4, 0, 0, 0], [8, 0, 0, 0]], lane = [[0, 0, 1, 0], [0, 0, 2, 0], [0, 0, 4, 0], [0, 0, 8, 0], [0, 0, 0, 4], [0, 1, 0, 0]], warp = [[1, 0, 0, 0], [2, 0, 0, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[0, 1, 0, 0], [0, 2, 0, 0], [2, 0, 0, 0], [4, 0, 0, 0], [8, 0, 0, 0], [0, 0, 4, 0], [0, 0, 8, 0]], lane = [[0, 0, 0, 1], [0, 0, 0, 2], [0, 0, 0, 4], [0, 0, 0, 8], [0, 4, 0, 0], [1, 0, 0, 0]], warp = [[0, 0, 1, 0], [0, 0, 2, 0]], block = []}> +#linear3 = #ttg.linear<{register = [[0, 0, 1], [0, 0, 2], [1, 0, 0], [2, 0, 0], [4, 0, 0], [8, 0, 0], [0, 64, 0]], lane = [[0, 1, 0], [0, 2, 0], [0, 4, 0], [0, 8, 0], [0, 0, 4], [0, 0, 8]], warp = [[0, 16, 0], [0, 32, 0]], block = []}> +#linear4 = #ttg.linear<{register = [[0, 1, 0], [0, 2, 0], [1, 0, 0], [2, 0, 0], [4, 0, 0], [8, 0, 0], [0, 0, 64]], lane = [[0, 0, 1], [0, 0, 2], [0, 0, 4], [0, 0, 8], [0, 4, 0], [0, 8, 0]], warp = [[0, 0, 16], [0, 0, 32]], block = []}> +#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 4], instrShape = [16, 16, 16], isTransposed = true}> +#shared = 
#ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @_paged_attn_decode_v2_w_dot_kernel_reshape_noloop_qk(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg8: f32, %arg9: f32, %arg10: f32, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32 {tt.divisibility = 16 : i32}, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}, %arg21: i32 {tt.divisibility = 16 : i32}, %arg22: i32 {tt.divisibility = 16 : i32}, %arg23: i32, %arg24: i32 {tt.divisibility = 16 : i32}, %arg25: i32 {tt.divisibility = 16 : i32}, %arg26: i32 {tt.divisibility = 16 : i32}, %arg27: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<8> : tensor<1x1x16x1xi32, #blocked> + %cst_0 = arith.constant dense<8> : tensor<16x1xi32, #mma> + %cst_1 = arith.constant dense<16> : tensor<16x1xi32, #linear> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x256xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #mma> + %cst_4 = arith.constant dense<128> : tensor<1x128xi32, #blocked1> + %cst_5 = arith.constant dense<1.44269502> : tensor<16x256xf32, #mma> + %cst_6 = arith.constant dense<0xFF800000> : tensor<16x256xf32, #mma> + %cst_7 = arith.constant dense<8> : tensor<16xi32, #blocked2> + %c15_i32 = arith.constant 15 : i32 + %c8_i32 = arith.constant 8 : i32 + %c16_i32 = arith.constant 16 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst_8 = arith.constant dense<0> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %cst_9 = arith.constant dense<0> : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %cst_10 = arith.constant dense<8> : tensor<16x1xi32, #blocked1> + %cst_11 = arith.constant dense<128> : tensor<1x128xi32, #mma> + %0 = tt.get_program_id x : i32 + %1 = tt.get_program_id y : i32 + %2 = tt.get_program_id z : i32 + %3 = tt.addptr %arg7, %0 : !tt.ptr, i32 + %4 = tt.load %3 : !tt.ptr + %5 = arith.muli %2, %c256_i32 : i32 + %6 = arith.cmpi sge, %5, %4 : i32 + cf.cond_br %6, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + tt.return + ^bb2: // pred: ^bb0 + %7 = arith.addi %5, %c256_i32 : i32 + %8 = arith.minsi %7, %4 : i32 + %9 = arith.subi %8, %5 : i32 + %10 = arith.addi %9, %c15_i32 : i32 + %11 = arith.divsi %10, %c16_i32 : i32 + %12 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %13 = 
tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %14 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked2> + %15 = tt.splat %11 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %16 = tt.splat %11 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %17 = arith.cmpi slt, %12, %15 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %18 = arith.cmpi slt, %13, %16 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %19 = arith.select %17, %12, %cst_8 : tensor<16xi1, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %20 = arith.select %18, %13, %cst_9 : tensor<16xi1, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>>, tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %21 = arith.muli %2, %c16_i32 : i32 + %22 = arith.muli %0, %arg27 : i32 + %23 = arith.addi %22, %21 : i32 + %24 = tt.splat %23 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %25 = tt.splat %23 : i32 -> tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %26 = arith.addi %19, %24 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %27 = arith.addi %20, %25 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %28 = amdgpu.buffer_load %arg6[%26] : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %29 = amdgpu.buffer_load %arg6[%27] : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %30 = arith.muli %0, %arg18 : i32 + %31 = arith.muli %1, %c8_i32 : i32 + %32 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>> + %33 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %34 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #linear}>> + %35 = tt.expand_dims %32 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xi32, #mma> + %36 = tt.expand_dims %33 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> + %37 = tt.expand_dims %34 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #linear}>> -> tensor<16x1xi32, #linear> + %38 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %39 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %40 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> + %41 = tt.expand_dims %38 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> 
tensor<1x128xi32, #mma> + %42 = tt.expand_dims %39 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %43 = arith.cmpi slt, %35, %cst_0 : tensor<16x1xi32, #mma> + %44 = arith.cmpi slt, %36, %cst_10 : tensor<16x1xi32, #blocked1> + %45 = arith.cmpi slt, %41, %cst_11 : tensor<1x128xi32, #mma> + %46 = arith.cmpi slt, %42, %cst_4 : tensor<1x128xi32, #blocked1> + %47 = tt.broadcast %43 : tensor<16x1xi1, #mma> -> tensor<16x128xi1, #mma> + %48 = tt.broadcast %44 : tensor<16x1xi1, #blocked1> -> tensor<16x128xi1, #blocked1> + %49 = tt.broadcast %45 : tensor<1x128xi1, #mma> -> tensor<16x128xi1, #mma> + %50 = tt.broadcast %46 : tensor<1x128xi1, #blocked1> -> tensor<16x128xi1, #blocked1> + %51 = arith.andi %47, %49 : tensor<16x128xi1, #mma> + %52 = arith.andi %48, %50 : tensor<16x128xi1, #blocked1> + %53 = arith.muli %31, %arg19 : i32 + %54 = tt.splat %arg19 : i32 -> tensor<16x1xi32, #blocked1> + %55 = arith.muli %36, %54 : tensor<16x1xi32, #blocked1> + %56 = arith.addi %30, %53 : i32 + %57 = tt.broadcast %55 : tensor<16x1xi32, #blocked1> -> tensor<16x128xi32, #blocked1> + %58 = tt.broadcast %42 : tensor<1x128xi32, #blocked1> -> tensor<16x128xi32, #blocked1> + %59 = arith.addi %57, %58 : tensor<16x128xi32, #blocked1> + %60 = tt.splat %56 : i32 -> tensor<16x128xi32, #blocked1> + %61 = arith.addi %60, %59 : tensor<16x128xi32, #blocked1> + %62 = amdgpu.buffer_load %arg3[%61], %52 : tensor<16x128xbf16, #blocked1> + %63 = arith.extf %62 : tensor<16x128xbf16, #blocked1> to tensor<16x128xf32, #blocked1> + %64 = tt.splat %arg8 : f32 -> tensor<16x128xf32, #blocked1> + %65 = arith.mulf %63, %64 : tensor<16x128xf32, #blocked1> + %66 = arith.truncf %65 : tensor<16x128xf32, #blocked1> to tensor<16x128xbf16, #blocked1> + %67 = arith.muli %1, %arg21 : i32 + %68 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> + %69 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %70 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked3}>}>> + %71 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> + %72 = tt.expand_dims %68 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x16xi32, #linear> + %73 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>}>> + %74 = tt.splat %21 : i32 -> tensor<16x1xi32, #linear> + %75 = arith.addi %74, %37 : tensor<16x1xi32, #linear> + %76 = arith.muli %75, %cst_1 : tensor<16x1xi32, #linear> + %77 = tt.broadcast %76 : tensor<16x1xi32, #linear> -> tensor<16x16xi32, #linear> + %78 = tt.broadcast %72 : tensor<1x16xi32, #linear> -> tensor<16x16xi32, #linear> + %79 = arith.addi %77, %78 : tensor<16x16xi32, #linear> + %80 = tt.expand_dims %29 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> -> tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>> + %81 = tt.expand_dims %80 {axis = 2 : i32} : tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>> -> 
tensor<16x1x1xi32, #ttg.slice<{dim = 3, parent = #blocked}>> + %82 = tt.expand_dims %81 {axis = 3 : i32} : tensor<16x1x1xi32, #ttg.slice<{dim = 3, parent = #blocked}>> -> tensor<16x1x1x1xi32, #blocked> + %83 = tt.splat %arg20 : i32 -> tensor<16x1x1x1xi32, #blocked> + %84 = arith.muli %82, %83 : tensor<16x1x1x1xi32, #blocked> + %85 = tt.broadcast %84 : tensor<16x1x1x1xi32, #blocked> -> tensor<16x16x1x1xi32, #blocked> + %86 = tt.expand_dims %69 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>> + %87 = tt.expand_dims %86 {axis = 2 : i32} : tensor<1x16xi32, #ttg.slice<{dim = 2, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>> -> tensor<1x16x1xi32, #ttg.slice<{dim = 3, parent = #blocked}>> + %88 = tt.expand_dims %87 {axis = 3 : i32} : tensor<1x16x1xi32, #ttg.slice<{dim = 3, parent = #blocked}>> -> tensor<1x16x1x1xi32, #blocked> + %89 = tt.splat %arg22 : i32 -> tensor<1x16x1x1xi32, #blocked> + %90 = arith.muli %88, %89 : tensor<1x16x1x1xi32, #blocked> + %91 = tt.broadcast %90 : tensor<1x16x1x1xi32, #blocked> -> tensor<16x16x1x1xi32, #blocked> + %92 = arith.addi %85, %91 : tensor<16x16x1x1xi32, #blocked> + %93 = tt.broadcast %92 : tensor<16x16x1x1xi32, #blocked> -> tensor<16x16x16x1xi32, #blocked> + %94 = tt.expand_dims %71 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>> + %95 = tt.expand_dims %94 {axis = 1 : i32} : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 3, parent = #blocked}>}>> -> tensor<1x1x16xi32, #ttg.slice<{dim = 3, parent = #blocked}>> + %96 = tt.expand_dims %95 {axis = 3 : i32} : tensor<1x1x16xi32, #ttg.slice<{dim = 3, parent = #blocked}>> -> tensor<1x1x16x1xi32, #blocked> + %97 = arith.muli %96, %cst : tensor<1x1x16x1xi32, #blocked> + %98 = tt.broadcast %97 : tensor<1x1x16x1xi32, #blocked> -> tensor<16x16x16x1xi32, #blocked> + %99 = arith.addi %93, %98 : tensor<16x16x16x1xi32, #blocked> + %100 = tt.broadcast %99 : tensor<16x16x16x1xi32, #blocked> -> tensor<16x16x16x8xi32, #blocked> + %101 = tt.expand_dims %73 {axis = 0 : i32} : tensor<8xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>}>> -> tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> + %102 = tt.expand_dims %101 {axis = 1 : i32} : tensor<1x8xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked}>}>> -> tensor<1x1x8xi32, #ttg.slice<{dim = 2, parent = #blocked}>> + %103 = tt.expand_dims %102 {axis = 2 : i32} : tensor<1x1x8xi32, #ttg.slice<{dim = 2, parent = #blocked}>> -> tensor<1x1x1x8xi32, #blocked> + %104 = tt.broadcast %103 : tensor<1x1x1x8xi32, #blocked> -> tensor<16x16x16x8xi32, #blocked> + %105 = arith.addi %100, %104 : tensor<16x16x16x8xi32, #blocked> + %106 = tt.splat %67 : i32 -> tensor<16x16x16x8xi32, #blocked> + %107 = arith.addi %106, %105 : tensor<16x16x16x8xi32, #blocked> + %108 = amdgpu.buffer_load %arg4[%107] : tensor<16x16x16x8xbf16, #blocked> + %109 = ttg.convert_layout %108 : tensor<16x16x16x8xbf16, #blocked> -> tensor<16x16x16x8xbf16, #linear1> + %110 = tt.trans %109 {order = array} : tensor<16x16x16x8xbf16, #linear1> -> tensor<16x8x16x16xbf16, #linear2> + %111 = tt.reshape %110 : tensor<16x8x16x16xbf16, 
#linear2> -> tensor<128x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %112 = ttg.local_alloc %66 : (tensor<16x128xbf16, #blocked1>) -> !ttg.memdesc<16x128xbf16, #shared, #smem> + %113 = ttg.local_load %112 : !ttg.memdesc<16x128xbf16, #shared, #smem> -> tensor<16x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %114 = tt.dot %113, %111, %cst_2, inputPrecision = tf32 : tensor<16x128xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x256xf32, #mma> + %115 = tt.reshape %79 : tensor<16x16xi32, #linear> -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> + %116 = tt.expand_dims %115 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma> + %117 = tt.splat %4 : i32 -> tensor<1x256xi32, #mma> + %118 = arith.cmpi slt, %116, %117 : tensor<1x256xi32, #mma> + %119 = tt.broadcast %43 : tensor<16x1xi1, #mma> -> tensor<16x256xi1, #mma> + %120 = tt.broadcast %118 : tensor<1x256xi1, #mma> -> tensor<16x256xi1, #mma> + %121 = arith.andi %119, %120 : tensor<16x256xi1, #mma> + %122 = arith.select %121, %114, %cst_6 : tensor<16x256xi1, #mma>, tensor<16x256xf32, #mma> + %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32, %arg29: f32): + %185 = arith.maxnumf %arg28, %arg29 : f32 + tt.reduce.return %185 : f32 + }) : (tensor<16x256xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + %124 = tt.expand_dims %123 {axis = 1 : i32} : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xf32, #mma> + %125 = tt.broadcast %124 : tensor<16x1xf32, #mma> -> tensor<16x256xf32, #mma> + %126 = arith.subf %122, %125 : tensor<16x256xf32, #mma> + %127 = arith.mulf %126, %cst_5 : tensor<16x256xf32, #mma> + %128 = math.exp2 %127 : tensor<16x256xf32, #mma> + %129 = arith.truncf %128 : tensor<16x256xf32, #mma> to tensor<16x256xbf16, #mma> + %130 = "tt.reduce"(%129) <{axis = 1 : i32}> ({ + ^bb0(%arg28: bf16, %arg29: bf16): + %185 = arith.addf %arg28, %arg29 : bf16 + tt.reduce.return %185 : bf16 + }) : (tensor<16x256xbf16, #mma>) -> tensor<16xbf16, #ttg.slice<{dim = 1, parent = #mma}>> + %131 = arith.muli %1, %arg25 : i32 + %132 = tt.expand_dims %28 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> -> tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> + %133 = tt.expand_dims %132 {axis = 2 : i32} : tensor<16x1xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> -> tensor<16x1x1xi32, #blocked3> + %134 = tt.splat %arg24 : i32 -> tensor<16x1x1xi32, #blocked3> + %135 = arith.muli %133, %134 : tensor<16x1x1xi32, #blocked3> + %136 = tt.broadcast %135 : tensor<16x1x1xi32, #blocked3> -> tensor<16x128x1xi32, #blocked3> + %137 = tt.expand_dims %40 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 2, parent = #blocked3}>}>> -> tensor<1x128xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> + %138 = tt.expand_dims %137 {axis = 2 : i32} : tensor<1x128xi32, #ttg.slice<{dim = 2, parent = #blocked3}>> -> tensor<1x128x1xi32, #blocked3> + %139 = tt.splat %arg26 : i32 -> tensor<1x128x1xi32, #blocked3> + %140 = arith.muli %138, %139 : tensor<1x128x1xi32, #blocked3> + %141 = tt.broadcast %140 : tensor<1x128x1xi32, #blocked3> -> tensor<16x128x1xi32, #blocked3> + %142 = arith.addi %136, %141 : tensor<16x128x1xi32, #blocked3> + %143 = tt.broadcast %142 : tensor<16x128x1xi32, #blocked3> -> tensor<16x128x16xi32, #blocked3> + %144 = tt.expand_dims 
%70 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #ttg.slice<{dim = 1, parent = #blocked3}>}>> -> tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %145 = tt.expand_dims %144 {axis = 1 : i32} : tensor<1x16xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1x16xi32, #blocked3> + %146 = tt.broadcast %145 : tensor<1x1x16xi32, #blocked3> -> tensor<16x128x16xi32, #blocked3> + %147 = arith.addi %143, %146 : tensor<16x128x16xi32, #blocked3> + %148 = tt.splat %131 : i32 -> tensor<16x128x16xi32, #blocked3> + %149 = arith.addi %148, %147 : tensor<16x128x16xi32, #blocked3> + %150 = amdgpu.buffer_load %arg5[%149] : tensor<16x128x16xbf16, #blocked3> + %151 = ttg.convert_layout %150 : tensor<16x128x16xbf16, #blocked3> -> tensor<16x128x16xbf16, #linear3> + %152 = tt.trans %151 {order = array} : tensor<16x128x16xbf16, #linear3> -> tensor<16x16x128xbf16, #linear4> + %153 = tt.reshape %152 : tensor<16x16x128xbf16, #linear4> -> tensor<256x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %154 = arith.muli %0, %arg11 : i32 + %155 = arith.muli %1, %arg12 : i32 + %156 = arith.addi %154, %155 : i32 + %157 = arith.muli %2, %arg13 : i32 + %158 = arith.addi %156, %157 : i32 + %159 = tt.splat %158 : i32 -> tensor<16xi32, #blocked2> + %160 = arith.cmpi slt, %14, %cst_7 : tensor<16xi32, #blocked2> + %161 = arith.addi %159, %14 : tensor<16xi32, #blocked2> + %162 = ttg.convert_layout %123 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16xf32, #blocked2> + amdgpu.buffer_store %162, %arg1[%161], %160 : tensor<16xf32, #blocked2> + %163 = ttg.convert_layout %130 : tensor<16xbf16, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16xbf16, #blocked2> + %164 = arith.extf %163 : tensor<16xbf16, #blocked2> to tensor<16xf32, #blocked2> + amdgpu.buffer_store %164, %arg0[%161], %160 : tensor<16xf32, #blocked2> + %165 = ttg.local_alloc %129 : (tensor<16x256xbf16, #mma>) -> !ttg.memdesc<16x256xbf16, #shared1, #smem> + %166 = ttg.local_load %165 : !ttg.memdesc<16x256xbf16, #shared1, #smem> -> tensor<16x256xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %167 = tt.dot %166, %153, %cst_3, inputPrecision = tf32 : tensor<16x256xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<256x128xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<16x128xf32, #mma> + %168 = tt.expand_dims %130 {axis = 1 : i32} : tensor<16xbf16, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<16x1xbf16, #mma> + %169 = arith.extf %168 : tensor<16x1xbf16, #mma> to tensor<16x1xf32, #mma> + %170 = tt.broadcast %169 : tensor<16x1xf32, #mma> -> tensor<16x128xf32, #mma> + %171 = arith.divf %167, %170 : tensor<16x128xf32, #mma> + %172 = arith.muli %0, %arg14 : i32 + %173 = arith.muli %1, %arg15 : i32 + %174 = arith.addi %172, %173 : i32 + %175 = arith.muli %2, %arg16 : i32 + %176 = tt.splat %arg17 : i32 -> tensor<16x1xi32, #mma> + %177 = arith.muli %35, %176 : tensor<16x1xi32, #mma> + %178 = tt.broadcast %177 : tensor<16x1xi32, #mma> -> tensor<16x128xi32, #mma> + %179 = tt.broadcast %41 : tensor<1x128xi32, #mma> -> tensor<16x128xi32, #mma> + %180 = arith.addi %178, %179 : tensor<16x128xi32, #mma> + %181 = arith.addi %174, %175 : i32 + %182 = tt.splat %181 : i32 -> tensor<16x128xi32, #mma> + %183 = arith.addi %182, %180 : tensor<16x128xi32, #mma> + %184 = arith.truncf %171 : tensor<16x128xf32, #mma> to tensor<16x128xbf16, #mma> + amdgpu.buffer_store %184, %arg2[%183], %51 : tensor<16x128xbf16, #mma> + tt.return + } +} diff --git 
a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index e7c5a6674dde..c0546e022135 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -249,6 +249,8 @@ def make_ttgir(mod, metadata, options):
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
+        amd.passes.ttgpuir.add_implicit_convert_layout(pm)
+        passes.ttgpuir.add_remove_layout_conversions(pm)
         if use_async_copy:
             amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch)
         pm.run(mod, 'make_ttgir')
diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
index bd9fc77e317c..475e6c9ee177 100644
--- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
+++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td
@@ -291,4 +291,14 @@ def TritonAMDGPUOptimizeDotOperands : Pass<"tritonamdgpu-optimize-dot-operands",
   ];
 }
 
+def TritonAMDGPUImplicitConvertLayout : Pass<"tritonamdgpu-implicit-convert-layout", "mlir::ModuleOp"> {
+  let summary = "Convert #blocked/#linear layouts feeding tt.dot into #dot_op layouts implicitly";
+
+  let description = "Rewrite the layout conversions and shared-memory round trips that produce tt.dot "
+                    "operands so the data is loaded directly in the #dot_op layout, avoiding extra conversion cost.";
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
index 98d39af2ba2a..78fd45176324 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_triton_library(TritonAMDGPUTransforms
+  ImplicitConvertLayout.cpp
   AccelerateAMDMatmul.cpp
   BlockPingpong.cpp
   CanonicalizePointers.cpp
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ImplicitConvertLayout.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ImplicitConvertLayout.cpp
new file mode 100644
index 000000000000..592c1e3bb979
--- /dev/null
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ImplicitConvertLayout.cpp
@@ -0,0 +1,274 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tritonamdgpu-implicit-convert-layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

namespace ttg = mlir::triton::gpu;
namespace tt = mlir::triton;

namespace mlir {

#define GEN_PASS_DEF_TRITONAMDGPUIMPLICITCONVERTLAYOUT
#include "TritonAMDGPUTransforms/Passes.h.inc"

namespace {
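// This pass rewrites the chain that produces a tt.dot operand so the data can
// be loaded from global memory directly in a register layout the dot can
// consume. Illustrative sketch (shapes and encodings borrowed from the lit
// test amd-implicit-convert-layout.mlir; details vary per kernel):
//
//   %l = amdgpu.buffer_load ... : tensor<16x16x16x8xbf16, #blocked>
//   %c = ttg.convert_layout %l  : #blocked -> #linear
//   %t = tt.trans %c            // permute to the dot's K-contiguous order
//   %b = tt.reshape %t          : ... -> tensor<128x256xbf16, #ttg.dot_op<...>>
//   tt.dot %a, %b, %acc
//
// The pass walks this chain backwards from the #dot_op encoding, re-expresses
// that encoding as a linear layout on each intermediate shape, retags the
// trans/reshape/load ops accordingly, and erases the now-redundant
// ttg.convert_layout. Any conversions it temporarily inserts are folded by the
// remove-layout-conversions pass scheduled right after it.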
struct ImplicitConvertLayoutPass
    : public impl::TritonAMDGPUImplicitConvertLayoutBase<
          ImplicitConvertLayoutPass> {
  static Type getNewType(Type type, Attribute encoding) {
    RankedTensorType tensorType = cast<RankedTensorType>(type);
    return tensorType.cloneWithEncoding(encoding);
  }

  // Rewrite `op` so that its tensor operands use `srcEncoding` and its results
  // use `dstEncoding`, bracketing it with ttg.convert_layout ops that the
  // remove-layout-conversions pass is expected to fold away.
  void coalesceOp(Attribute srcEncoding, Attribute dstEncoding, Operation *op) {
    OpBuilder builder(op);
    // Convert operands.
    // For load/store with tensor pointers, we don't have to change the
    // operands' type; we do this by changing the outputs' type of
    // `make_tensor_ptr`.
    SmallVector<Value> newArgs;
    for (auto operand : op->getOperands()) {
      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
      if (tensorType &&
          !isa<ttg::SharedEncodingTrait>(tensorType.getEncoding())) {
        Type newType = getNewType(tensorType, srcEncoding);
        newArgs.push_back(builder.create<ttg::ConvertLayoutOp>(
            op->getLoc(), newType, operand));
      } else {
        newArgs.push_back(operand);
      }
    }

    // Convert output types.
    SmallVector<Type> newTypes;
    for (auto t : op->getResultTypes()) {
      bool isAsync = isa<ttg::AsyncCopyGlobalToLocalOp>(op);
      newTypes.push_back(isAsync ? t : getNewType(t, dstEncoding));
    }

    // Construct a new op with the new encoding.
    Operation *newOp =
        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
                       newTypes, op->getAttrs());

    // Cast the results back to the original layout.
    for (size_t i = 0; i < op->getNumResults(); i++) {
      Value newResult = newOp->getResult(i);
      if (newTypes[i] != op->getResultTypes()[i]) {
        newResult = builder.create<ttg::ConvertLayoutOp>(
            op->getLoc(), op->getResult(i).getType(), newResult);
      }
      op->getResult(i).replaceAllUsesWith(newResult);
    }
    op->erase();
  }

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);

    moduleOp.walk([&](Operation *cur) {
      // Debug-only dump of loads and their layouts.
      if (auto loadOp = dyn_cast<tt::LoadOp>(cur)) {
        auto type = loadOp.getResult().getType();
        if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
          auto encoding = tensorTy.getEncoding();
          LDBG("load op: " << loadOp);
          LDBG("type: " << type);
          LDBG("encoding: " << encoding);
          if (auto blockedEncoding =
                  dyn_cast<ttg::BlockedEncodingAttr>(encoding)) {
            LDBG("blocked encoding to linear layout: "
                 << blockedEncoding.toLinearLayout(tensorTy.getShape()));
          }
        }
      }

      auto dot = dyn_cast<tt::DotOp>(cur);
      if (!dot)
        return;

      // 1. Check whether the dot's B operand satisfies the implicit
      //    conversion conditions.
      auto BOperand = dot.getB();
      RankedTensorType BOperandTy = BOperand.getType();
      auto opEncoding =
          dyn_cast<ttg::DotOperandEncodingAttr>(BOperandTy.getEncoding());
      if (!opEncoding)
        return;

      // Collect the backward slice of B up to (and including) the load op.
      BackwardSliceOptions opt;
      opt.omitBlockArguments = true;
      auto filter = [&dot](Operation *op) {
        return op->getParentRegion() == dot->getParentRegion();
      };
      opt.filter = filter;
      opt.inclusive = true;
      llvm::SetVector<Operation *> backwardSlices;
      llvm::SmallVector<Operation *> reversedBackwardSlices;
      (void)getBackwardSlice(BOperand, &backwardSlices, opt);
      for (auto sliceIter = backwardSlices.rbegin();
           sliceIter != backwardSlices.rend(); sliceIter++) {
        reversedBackwardSlices.emplace_back(*sliceIter);
        if (isa<tt::LoadOp>(*sliceIter))
          break;
      }
      if (reversedBackwardSlices.empty() ||
          !isa<tt::LoadOp>(reversedBackwardSlices.back()))
        return;

      // Get the vectorization factor of the load op; it becomes the kWidth of
      // the new dot operand layout.
      tt::LoadOp loadOp = dyn_cast<tt::LoadOp>(reversedBackwardSlices.back());
      auto loadTy = loadOp.getType();
      int vecFactor = 1;
      const int MIN_KWIDTH = 4;
      if (auto loadTensorTy = dyn_cast<RankedTensorType>(loadTy)) {
        if (auto loadBlockedLayout = dyn_cast<ttg::BlockedEncodingAttr>(
                loadTensorTy.getEncoding())) {
          auto sizePerThread = loadBlockedLayout.getSizePerThread();
          auto loadOrder = loadBlockedLayout.getOrder();
          vecFactor = sizePerThread[loadOrder[0]];
        }
      }
      if (vecFactor < MIN_KWIDTH)
        return;
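      // Worked example (shapes from the lit test; illustrative only): operand
      // B has type tensor<128x256xbf16, #dot_op<{opIdx = 1, kWidth = 4}>> and
      // is produced by
      //   buffer_load : 16x16x16x8 #blocked
      //     -> convert_layout -> trans {order = [0, 3, 1, 2]} -> reshape
      // Starting from the #dot_op encoding expressed as a linear layout, the
      // loop below "un-reshapes" and "un-transposes" it step by step, so every
      // op in the chain (and finally the load itself) is assigned the input
      // layout it must produce for B to reach tt.dot without a conversion.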
      // 2. Infer the layout conversion backwards, from the tt.dot operand to
      //    the load. `layoutMap` maps each op in the backward slice to its
      //    required input layout.
      llvm::MapVector<Operation *, Attribute> layoutMap;
      auto newBOpLayout = ttg::DotOperandEncodingAttr::get(
          BOperandTy.getContext(), 1, opEncoding.getParent(), vecFactor);
      ttg::LinearEncodingAttr curLayout = ttg::LinearEncodingAttr::get(
          BOperandTy.getContext(),
          newBOpLayout.toLinearLayout(BOperandTy.getShape()));
      for (auto slice : reversedBackwardSlices) {
        if (!isa<ttg::LocalLoadOp, ttg::LocalAllocOp>(slice)) {
          tt::LinearLayout linearLayout = curLayout.getLinearLayout();
          if (auto transOp = dyn_cast<tt::TransOp>(slice)) {
            // Undo the transpose by applying the inverse permutation.
            auto transOrder = to_vector(transOp.getOrder());
            auto originOrder = transOrder;
            for (size_t i = 0; i < transOrder.size(); i++)
              originOrder[transOrder[i]] = i;
            linearLayout =
                transposeLinearLayout(curLayout.getLinearLayout(), originOrder);
          } else if (auto reshapeOp = dyn_cast<tt::ReshapeOp>(slice)) {
            // Undo the reshape by re-expressing the layout on the source shape.
            auto originShape = reshapeOp.getSrc().getType().getShape();
            linearLayout = reshapeLayout(
                slice->getContext(), curLayout.getLinearLayout(), originShape);
          }
          curLayout = ttg::LinearEncodingAttr::get(BOperandTy.getContext(),
                                                   linearLayout);
          layoutMap[slice] = curLayout;
          LDBG("slice: " << *slice << " -> input layout: " << layoutMap[slice]);
        } else {
          assert(false &&
                 "local load/alloc should not appear in implicit convert layout");
        }
      }

      // 3. Propagate the inferred layouts to the forward slices (the backward
      //    direction is handled by the remove-layout-conversions pass).
      for (auto it = reversedBackwardSlices.rbegin();
           it != reversedBackwardSlices.rend(); it++) {
        Operation *slice = *it;
        if (isa<ttg::ConvertLayoutOp>(slice)) {
          // Existing conversions in the chain become redundant; fold them away.
          Value srcVal = slice->getOperand(0);
          Value dstVal = slice->getResult(0);
          dstVal.replaceAllUsesWith(srcVal);
          layoutMap.erase(slice);
          slice->erase();
        } else {
          OpBuilder rewriter(slice);
          Attribute srcEncoding = layoutMap[slice];
          Attribute dstEncoding = inferDstEncoding(slice, srcEncoding);
          if (slice == reversedBackwardSlices.front())
            dstEncoding = newBOpLayout;
          LDBG("op: " << *slice);
          LDBG("src encoding: " << srcEncoding);
          LDBG("dst encoding: " << dstEncoding);

          // `coalesceOp` inserts convert_layout ops before and after `slice`;
          // they are removed later by the remove-layout-conversions pass.
          coalesceOp(srcEncoding, dstEncoding, slice);
        }
      }
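      // What remains: operand B still carries its original #dot_op type, and
      // operand A's local_load still produces the old kWidth. Step 4 re-wraps
      // B in a convert_layout to the rebuilt #dot_op encoding (kWidth =
      // vecFactor), which remove-layout-conversions can then fold into the
      // retagged chain; step 5 retypes A's local_load so its kWidth matches
      // the new B encoding.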
      // 4. Replace the layout of operand B.
      BOperand = dot.getB();
      OpBuilder rewriter(BOperand.getDefiningOp());
      rewriter.setInsertionPointAfter(BOperand.getDefiningOp());
      auto newBType = RankedTensorType::get(
          BOperand.getType().getShape(), BOperand.getType().getElementType(),
          newBOpLayout);
      auto newBOperand = rewriter.create<ttg::ConvertLayoutOp>(
          BOperand.getDefiningOp()->getLoc(), newBType, BOperand);
      BOperand.replaceAllUsesExcept(newBOperand, newBOperand);

      LDBG("dot op: " << dot);
      LDBG("B tensor type: " << BOperandTy);
      LDBG("encoding: " << opEncoding);
      LDBG("vectorization factor: " << vecFactor);
      LDBG("new layout: " << newBOpLayout);
      LDBG("BOperand: " << BOperand);
      LDBG("BOperand defining op: " << *BOperand.getDefiningOp());

      // 5. Replace the layout of operand A.
      auto AOperand = dot.getA();
      auto AOperandTy = AOperand.getType();
      opEncoding =
          dyn_cast<ttg::DotOperandEncodingAttr>(AOperandTy.getEncoding());
      if (!opEncoding)
        return;
      auto newAOpLayout = ttg::DotOperandEncodingAttr::get(
          AOperandTy.getContext(), 0, opEncoding.getParent(), vecFactor);

      // Assume A (opIdx = 0) is defined by a ttg.local_load and change that
      // op's result layout directly.
      auto localLoadOp = dyn_cast<ttg::LocalLoadOp>(AOperand.getDefiningOp());
      assert(localLoadOp && "A should be defined by local load");
      rewriter.setInsertionPointAfter(localLoadOp);
      auto newLocalLoadOp = rewriter.clone(*localLoadOp);
      AOperandTy = AOperandTy.cloneWithEncoding(newAOpLayout);
      newLocalLoadOp->getResult(0).setType(AOperandTy);
      AOperand.replaceAllUsesWith(newLocalLoadOp->getResult(0));
      localLoadOp->erase();
    });
  }
};

} // namespace
} // namespace mlir
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index fb3880c5715c..f8977a58ab48 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -89,6 +89,7 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_reorder_instructions",
                      mlir::createTritonAMDGPUReorderInstructions);
   ADD_PASS_WRAPPER_0("add_fold_true_cmpi", mlir::createTritonAMDFoldTrueCmpI);
+  ADD_PASS_WRAPPER_0("add_implicit_convert_layout", mlir::createTritonAMDGPUImplicitConvertLayout);
   ADD_PASS_OPTION_WRAPPER_1("add_block_pingpong",
                             mlir::createTritonAMDGPUBlockPingpong, int32_t);
   ADD_PASS_OPTION_WRAPPER_1("add_schedule_loops",