From 64eb6d9453a2189bcbc3fd383d78d190ea3dfb89 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 26 Mar 2025 15:59:47 -0400 Subject: [PATCH 1/7] [AMD] DCE/canonicalize true epilogue conditionals --- test/TritonGPU/amd/amd-range-analysis.mlir | 128 +++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir index f4ad3b17b096..12535c94a7cf 100644 --- a/test/TritonGPU/amd/amd-range-analysis.mlir +++ b/test/TritonGPU/amd/amd-range-analysis.mlir @@ -1348,3 +1348,131 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @assume_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> { + // expected-remark@+1 {{unsigned : [18446744073709551615, 18446744073709551615] signed : [-1, -1]}} + %c-1 = arith.constant -1 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1 = arith.constant 1 : index + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0 = arith.constant 0 : index + // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}} + // expected-remark@+1 {{non-neg}} + %c1_i32 = arith.constant 1 : i32 + // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}} + // expected-remark@+1 {{non-neg}} + %c0_i32 = arith.constant 0 : i32 + // expected-remark@+1 {{unsigned : [1, 1] signed : [-1, -1]}} + %true = arith.constant true + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // expected-remark@+2 {{unsigned : [4, 4] signed : [4, 4]}} + // expected-remark@+1 {{non-neg}} + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + // expected-remark@+2 {{unsigned : [4, 4] signed : [4, 4]}} + // expected-remark@+1 {{non-neg}} + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} + // expected-remark@+1 {{non-neg}} + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} + // expected-remark@+1 {{non-neg}} + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}} + // expected-remark@+1 {{non-neg}} + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1> + %4 = tt.addptr %0, 
%3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}} + // expected-remark@+1 {{non-neg}} + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}} + // expected-remark@+1 {{non-neg}} + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}} + // expected-remark@+1 {{non-neg}} + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %12 = arith.cmpi slt, %arg0, %arg1 : index + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr, #blocked1> + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + ttg.local_store %14, %17 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %19 = arith.subi %arg1, %arg2 : index + %20:6 = scf.for %arg5 = %arg0 to %19 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) { + %33 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %34 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + llvm.intr.assume %true : i1 + %35 = tt.load %33 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr, #blocked1> + %36 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %37 = tt.load %34 : tensor<32x128x!tt.ptr, #blocked> + %38 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %39 = arith.mulf %38, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %40 = tt.dot %36, %39, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, 
#ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %41 = arith.addi %arg9, %c1_i32 : i32 + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %42 = arith.cmpi slt, %41, %c1_i32 : i32 + // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}} + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + ttg.local_store %35, %44 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + ttg.local_store %37, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + scf.yield %33, %34, %40, %43, %44, %45 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + } + // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}} + %21 = arith.cmpi slt, %arg2, %c0 : index + // expected-remark@+1 {{unsigned : [1, 18446744073709551615] signed : [-1, 1]}} + %22 = arith.select %21, %c1, %c-1 : index + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %23 = arith.subi %arg1, %arg0 : index + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %24 = arith.addi %23, %arg2 : index + // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}} + %25 = arith.addi %24, %22 : index + // expected-remark@+2 {{unsigned : [1, 9223372036854775807] signed : [1, 9223372036854775807]}} + // expected-remark@+1 {{non-neg}} + %26 = arith.divsi %25, %arg2 : index + %28 = ttg.local_load %20#4 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %29 = ttg.local_load %20#5 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %30 = arith.mulf %29, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}} + // expected-remark@+1 {{result is true}} + %27 = arith.cmpi sge, %26, %c1 : index + llvm.intr.assume %27 : i1 + %31 = scf.if %27 -> (tensor<128x128xf32, #mma>) { + %33 = tt.dot %28, %30, %20#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + scf.yield %33 : tensor<128x128xf32, #mma> + } else { + scf.yield %20#2 : tensor<128x128xf32, #mma> + } + %32 = arith.select %27, %31, %20#2 : tensor<128x128xf32, #mma> + ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> + tt.return %32 : tensor<128x128xf32, #mma> + } +} From 95d7dc0f3e676b2ebfd5efd04f825f175fbf0a18 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 28 Mar 2025 15:09:31 -0400 Subject: [PATCH 2/7] add fold-true-cmpi 
pattern/test pass --- bin/RegisterTritonDialects.h | 2 + test/TritonGPU/amd/amd-fold-true-cmpi.mlir | 108 ++++++++++++++++++ .../amd/include/Analysis/RangeAnalysis.h | 3 +- .../amd/lib/Analysis/RangeAnalysis.cpp | 27 +++++ .../TritonAMDGPUTransforms/StreamPipeline.cpp | 2 +- .../amd/test/lib/Analysis/CMakeLists.txt | 1 + .../test/lib/Analysis/TestFoldTrueCmpIOp.cpp | 47 ++++++++ 7 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 test/TritonGPU/amd/amd-fold-true-cmpi.mlir create mode 100644 third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 38030dad983e..4c33041c27b2 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -34,6 +34,7 @@ void registerTestAlignmentPass(); void registerTestAllocationPass(); void registerTestMembarPass(); void registerTestTritonAMDGPURangeAnalysis(); +void registerTestTritonAMDGPUFoldTrueCmpIOp(); } // namespace test } // namespace mlir @@ -47,6 +48,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::test::registerTestAllocationPass(); mlir::test::registerTestMembarPass(); mlir::test::registerTestTritonAMDGPURangeAnalysis(); + mlir::test::registerTestTritonAMDGPUFoldTrueCmpIOp(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::gpu::registerAllocateSharedMemoryPass(); mlir::triton::gpu::registerTritonGPUAllocateWarpGroups(); diff --git a/test/TritonGPU/amd/amd-fold-true-cmpi.mlir b/test/TritonGPU/amd/amd-fold-true-cmpi.mlir new file mode 100644 index 000000000000..05375550a2f0 --- /dev/null +++ b/test/TritonGPU/amd/amd-fold-true-cmpi.mlir @@ -0,0 +1,108 @@ +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-fold-true-cmpi -canonicalize | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + tt.func @assume_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr, %arg4: !tt.ptr) -> tensor<128x128xf32, #mma> { + %c-1 = arith.constant -1 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %true = arith.constant true + %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked> + %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked> + %0 = tt.splat %arg3 : !tt.ptr -> tensor<128x32x!tt.ptr, #blocked1> + %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1> + %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, 
#blocked1> + %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %5 = tt.splat %arg4 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked> + %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked> + %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> + %12 = arith.cmpi slt, %arg0, %arg1 : index + %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1> + %14 = tt.load %4, %13 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr, #blocked1> + %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked> + %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr, #blocked> + %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + ttg.local_store %14, %17 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + %19 = arith.subi %arg1, %arg2 : index + %20:6 = scf.for %arg5 = %arg0 to %19 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) { + %33 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1> + %34 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked> + llvm.intr.assume %true : i1 + %35 = tt.load %33 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr, #blocked1> + %36 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %37 = tt.load %34 : tensor<32x128x!tt.ptr, #blocked> + %38 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %39 = arith.mulf %38, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %40 = tt.dot %36, %39, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %41 = arith.addi %arg9, %c1_i32 : i32 + %42 = arith.cmpi slt, %41, %c1_i32 : i32 + %43 = arith.select %42, %41, %c0_i32 : i32 + %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + ttg.local_store %35, %44 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> + %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : 
!ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + ttg.local_store %37, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + scf.yield %33, %34, %40, %43, %44, %45 : tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> + } + %21 = arith.cmpi slt, %arg2, %c0 : index + %22 = arith.select %21, %c1, %c-1 : index + %23 = arith.subi %arg1, %arg0 : index + %24 = arith.addi %23, %arg2 : index + %25 = arith.addi %24, %22 : index + %26 = arith.divsi %25, %arg2 : index + %28 = ttg.local_load %20#4 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %29 = ttg.local_load %20#5 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %30 = arith.mulf %29, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %27 = arith.cmpi sge, %26, %c1 : index + llvm.intr.assume %27 : i1 + %31 = scf.if %27 -> (tensor<128x128xf32, #mma>) { + %33 = tt.dot %28, %30, %20#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + scf.yield %33 : tensor<128x128xf32, #mma> + } else { + scf.yield %20#2 : tensor<128x128xf32, #mma> + } + %32 = arith.select %27, %31, %20#2 : tensor<128x128xf32, #mma> + ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> + ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> + tt.return %32 : tensor<128x128xf32, #mma> + } +} + +// CHECK: #[[$ATTR_2:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +// CHECK: #[[$ATTR_4:.+]] = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}> +// CHECK: #[[$ATTR_5:.+]] = #ttg.shared_memory + +// CHECK-LABEL: tt.func @assume_matmul( +// CHECK: %[[VAL_7:.*]] = arith.constant true +// CHECK: %[[VAL_8:.*]] = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>> +// CHECK: %[[VAL_23:.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_24:.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_33:.*]]:6 = scf.for +// CHECK: scf.yield +// CHECK: } +// CHECK-NEXT: %[[VAL_54:.*]] = ttg.local_load %[[VAL_55:.*]]#4 : !ttg.memdesc<128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$ATTR_2]], kWidth = 2}>> +// CHECK-NEXT: %[[VAL_56:.*]] = ttg.local_load %[[VAL_55]]#5 : !ttg.memdesc<32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>> +// CHECK-NEXT: %[[VAL_57:.*]] = arith.mulf %[[VAL_56]], %[[VAL_8]] : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>> +// CHECK-NEXT: llvm.intr.assume %[[VAL_7]] : i1 +// CHECK-NEXT: %[[VAL_58:.*]] = tt.dot %[[VAL_54]], %[[VAL_57]], %[[VAL_55]]#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$ATTR_2]], kWidth = 2}>> * tensor<32x128xf16, 
#ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>> -> tensor<128x128xf32, #[[$ATTR_2]]>
+// CHECK-NEXT: ttg.local_dealloc %[[VAL_23]] : !ttg.memdesc<1x128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable>
+// CHECK-NEXT: ttg.local_dealloc %[[VAL_24]] : !ttg.memdesc<1x32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable>
+// CHECK-NEXT: tt.return %[[VAL_58]] : tensor<128x128xf32, #[[$ATTR_2]]>
+// CHECK-NEXT: }
diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h
index 14d93f164264..6c4b95ddd28d 100644
--- a/third_party/amd/include/Analysis/RangeAnalysis.h
+++ b/third_party/amd/include/Analysis/RangeAnalysis.h
@@ -122,7 +122,8 @@ collectRanges(const DataFlowSolver &solver, ValueRange values);

 bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp);

-bool isEmptyInitializedRange(ConstantIntRanges rv);
+void populateFoldTrueCmpIOpPatterns(
+    RewritePatternSet &patterns, std::shared_ptr<DataFlowSolver> solver);

 } // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp
index f384194b7abd..79d643c9918f 100644
--- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp
+++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp
@@ -474,4 +474,31 @@ TritonIntegerRangeAnalysis::collectAssumptions(Operation *rootOp,
   return assumptions;
 }

+struct FoldTrueCmpIOp : OpRewritePattern<arith::CmpIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldTrueCmpIOp(MLIRContext *context, std::shared_ptr<DataFlowSolver> solver)
+      : OpRewritePattern(context), solver(std::move(solver)) {};
+
+  LogicalResult matchAndRewrite(arith::CmpIOp cmpOp,
+                                PatternRewriter &rewriter) const override {
+    if (cmpIIsStaticallyTrue(*solver, cmpOp)) {
+      if (failed(mlir::dataflow::maybeReplaceWithConstant(*solver, rewriter,
+                                                          cmpOp.getResult()))) {
+        LDBG("failed to replace with constant op: " << cmpOp);
+      }
+    } else {
+      return failure();
+    }
+    return success();
+  }
+
+  std::shared_ptr<DataFlowSolver> solver;
+};
+
+void populateFoldTrueCmpIOpPatterns(RewritePatternSet &patterns,
+                                    std::shared_ptr<DataFlowSolver> solver) {
+  patterns.add<FoldTrueCmpIOp>(patterns.getContext(), std::move(solver));
+}
+
 } // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 3bd19de27c9f..551ee8e894d9 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -1,6 +1,6 @@
 #include "TritonAMDGPUTransforms/Passes.h"
-#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
+#include "third_party/amd/include/Analysis/RangeAnalysis.h"
 #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h"
 #include "triton/Analysis/AxisInfo.h"
diff --git a/third_party/amd/test/lib/Analysis/CMakeLists.txt b/third_party/amd/test/lib/Analysis/CMakeLists.txt
index 52342109a3e9..1c9b69db6618 100644
--- a/third_party/amd/test/lib/Analysis/CMakeLists.txt
+++ b/third_party/amd/test/lib/Analysis/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_library(TritonAMDGPUTestAnalysis
   TestAMDRangeAnalysis.cpp
+  TestFoldTrueCmpIOp.cpp

   DEPENDS
   TritonTableGen
diff --git a/third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp b/third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp
new file mode 100644
index 000000000000..73fd8a27bc9a
--- /dev/null
+++ b/third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp
@@ -0,0 +1,47 @@
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
"mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "third_party/amd/include/Analysis/RangeAnalysis.h" +#include "triton/Analysis/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +struct TestAMDFoldTrueCmpIOpPass + : PassWrapper> { + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAMDFoldTrueCmpIOpPass) + + StringRef getArgument() const final { + return "test-tritonamdgpu-fold-true-cmpi"; + } + StringRef getDescription() const final { + return "print the result of the tritonamdgpu-fold-true-cmpi pass"; + } + + void runOnOperation() override { + DenseMap> assumptions = + AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); + std::shared_ptr solver = createDataFlowSolver(); + solver->load(assumptions); + if (failed(solver->initializeAndRun(getOperation()))) + return signalPassFailure(); + + ModuleOp mod = getOperation(); + RewritePatternSet patterns(&getContext()); + AMD::populateFoldTrueCmpIOpPatterns(patterns, solver); + if (failed(applyPatternsGreedily(mod, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir::test { +void registerTestTritonAMDGPUFoldTrueCmpIOp() { + PassRegistration(); +} +} // namespace mlir::test From efefa95b94d32e3a4cb02783a70cecfaca766d9d Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 28 Mar 2025 15:12:10 -0400 Subject: [PATCH 3/7] add fold-true-cmpi pattern to StreamPipeline.cpp --- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 551ee8e894d9..7ccc86d1b7a2 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -1,5 +1,6 @@ #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "third_party/amd/include/Analysis/RangeAnalysis.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" @@ -1058,9 +1059,20 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { continue; StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), globalPrefetch, localPrefetch, useAsyncCopy); - if (failed(sp.pipelineLoop())) - continue; + (void)sp.pipelineLoop(); } + + DenseMap> assumptions = + tt::AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation()); + std::shared_ptr solver = createDataFlowSolver(); + solver->load(assumptions); + if (failed(solver->initializeAndRun(getOperation()))) + return signalPassFailure(); + + ModuleOp mod = getOperation(); + RewritePatternSet patterns(&getContext()); + tt::AMD::populateFoldTrueCmpIOpPatterns(patterns, solver); + (void)applyPatternsGreedily(mod, std::move(patterns)); } }; } // namespace From 0e55f558b482d258f7401eaee57d0bc449403750 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 28 Mar 2025 15:21:50 -0400 Subject: [PATCH 4/7] add tests --- test/TritonGPU/amd/amd-fold-true-cmpi.mlir | 50 +++++++++++++++++++ .../amd/include/Analysis/RangeAnalysis.h | 6 ++- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/test/TritonGPU/amd/amd-fold-true-cmpi.mlir b/test/TritonGPU/amd/amd-fold-true-cmpi.mlir index 05375550a2f0..b451c5dabcd3 100644 --- 
a/test/TritonGPU/amd/amd-fold-true-cmpi.mlir +++ b/test/TritonGPU/amd/amd-fold-true-cmpi.mlir @@ -1,5 +1,55 @@ // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-fold-true-cmpi -canonicalize | FileCheck %s +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @cmpsle(%arg0: !tt.ptr) -> i1 { + %c0 = arith.constant 0 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %cmpsle = arith.cmpi sle, %c0, %c1024_i32 : i32 + tt.return %cmpsle: i1 + } +} + +// CHECK-LABEL: tt.func @cmpsle( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> i1 { +// CHECK: %[[VAL_1:.*]] = arith.constant true +// CHECK: tt.return %[[VAL_1]] : i1 +// CHECK: } + +// ----- + +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func @assumepid(%arg0: !tt.ptr) -> tensor<1024xf32> { + %c0 = arith.constant 0 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %pid = tt.get_program_id x : i32 + %cmpsle = arith.cmpi sle, %pid, %c1024_i32 : i32 + llvm.intr.assume %cmpsle : i1 + %cmpsge = arith.cmpi sge, %pid, %c0 : i32 + llvm.intr.assume %cmpsge : i1 + %1 = arith.muli %pid, %c1024_i32 : i32 + %2 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> + %4 = tt.load %3 : tensor<1024x!tt.ptr> + tt.return %4 : tensor<1024xf32> + } +} + +// CHECK-LABEL: tt.func @assumepid( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1024xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant true +// CHECK: %[[VAL_2:.*]] = arith.constant 1024 : i32 +// CHECK: %[[VAL_3:.*]] = tt.get_program_id x : i32 +// CHECK: llvm.intr.assume %[[VAL_1]] : i1 +// CHECK: llvm.intr.assume %[[VAL_1]] : i1 +// CHECK: %[[VAL_4:.*]] = arith.muli %[[VAL_3]], %[[VAL_2]] : i32 +// CHECK: %[[VAL_5:.*]] = tt.addptr %[[VAL_0]], %[[VAL_4]] : !tt.ptr, i32 +// CHECK: %[[VAL_6:.*]] = tt.splat %[[VAL_5]] : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : tensor<1024x!tt.ptr> +// CHECK: tt.return %[[VAL_7]] : tensor<1024xf32> +// CHECK: } + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}> diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h index 6c4b95ddd28d..6d2d4fb11a85 100644 --- a/third_party/amd/include/Analysis/RangeAnalysis.h +++ b/third_party/amd/include/Analysis/RangeAnalysis.h @@ -122,8 +122,10 @@ collectRanges(const DataFlowSolver &solver, ValueRange values); bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp); -void populateFoldTrueCmpIOpPatterns( - RewritePatternSet &patterns, std::shared_ptr solver); +bool isEmptyInitializedRange(ConstantIntRanges rv); + +void populateFoldTrueCmpIOpPatterns(RewritePatternSet &patterns, + std::shared_ptr solver); } // namespace mlir::triton::AMD From 9bc5cec7d9751153dcfec419ae98c23156403c7c Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 1 Apr 2025 16:41:36 -0400 Subject: [PATCH 5/7] special case in test_assume --- python/test/unit/language/test_core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index c8ce5ed9ff1b..09e04ddc455d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4666,7 +4666,10 @@ def 
     if is_interpreter():
         return

-    assert 'llvm.assume' in pgm.asm['llir']
+    assert 'llvm.intr.assume' in pgm.asm['ttgir']
+    # stream pipeliner on AMD folds true cmpi ops to %true (which llvm itself then dces)
+    if not is_hip():
+        assert 'llvm.assume' in pgm.asm['llir']


 # ---------------
From e6703de1277945afc2f6bbb23c873271791384e1 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Tue, 1 Apr 2025 17:22:19 -0400
Subject: [PATCH 6/7] address comments

---
 third_party/amd/include/Analysis/RangeAnalysis.h      |  2 +-
 third_party/amd/lib/Analysis/RangeAnalysis.cpp        | 11 ++++++-----
 .../amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp |  4 ++--
 .../amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp      |  4 ++--
 4 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h
index 6d2d4fb11a85..32e8fc88faa4 100644
--- a/third_party/amd/include/Analysis/RangeAnalysis.h
+++ b/third_party/amd/include/Analysis/RangeAnalysis.h
@@ -125,7 +125,7 @@ bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp);
 bool isEmptyInitializedRange(ConstantIntRanges rv);

 void populateFoldTrueCmpIOpPatterns(RewritePatternSet &patterns,
-                                    std::shared_ptr<DataFlowSolver> solver);
+                                    DataFlowSolver *solver);

 } // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp
index 79d643c9918f..82bb912bc220 100644
--- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp
+++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp
@@ -477,8 +477,8 @@ TritonIntegerRangeAnalysis::collectAssumptions(Operation *rootOp,
 struct FoldTrueCmpIOp : OpRewritePattern<arith::CmpIOp> {
   using OpRewritePattern::OpRewritePattern;

-  FoldTrueCmpIOp(MLIRContext *context, std::shared_ptr<DataFlowSolver> solver)
-      : OpRewritePattern(context), solver(std::move(solver)) {};
+  FoldTrueCmpIOp(MLIRContext *context, DataFlowSolver *solver)
+      : OpRewritePattern(context), solver(solver) {};

   LogicalResult matchAndRewrite(arith::CmpIOp cmpOp,
                                 PatternRewriter &rewriter) const override {
@@ -486,6 +486,7 @@ struct FoldTrueCmpIOp : OpRewritePattern<arith::CmpIOp> {
       if (failed(mlir::dataflow::maybeReplaceWithConstant(*solver, rewriter,
                                                           cmpOp.getResult()))) {
         LDBG("failed to replace with constant op: " << cmpOp);
+        return failure();
       }
     } else {
       return failure();
@@ -493,12 +494,12 @@ struct FoldTrueCmpIOp : OpRewritePattern<arith::CmpIOp> {
     return success();
   }

-  std::shared_ptr<DataFlowSolver> solver;
+  DataFlowSolver *solver;
 };

 void populateFoldTrueCmpIOpPatterns(RewritePatternSet &patterns,
-                                    std::shared_ptr<DataFlowSolver> solver) {
-  patterns.add<FoldTrueCmpIOp>(patterns.getContext(), std::move(solver));
+                                    DataFlowSolver *solver) {
+  patterns.add<FoldTrueCmpIOp>(patterns.getContext(), solver);
 }

 } // namespace mlir::triton::AMD
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 7ccc86d1b7a2..7e3e80fcd8a3 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -1064,14 +1064,14 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {

     DenseMap<Value, SetVector<Operation *>> assumptions =
         tt::AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation());
-    std::shared_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
     solver->load<tt::AMD::TritonIntegerRangeAnalysis>(assumptions);
     if (failed(solver->initializeAndRun(getOperation())))
       return signalPassFailure();

     ModuleOp mod = getOperation();
     RewritePatternSet patterns(&getContext());
-    tt::AMD::populateFoldTrueCmpIOpPatterns(patterns, solver);
+    tt::AMD::populateFoldTrueCmpIOpPatterns(patterns, solver.get());
     (void)applyPatternsGreedily(mod, std::move(patterns));
   }
 };
diff --git a/third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp b/third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp
index 73fd8a27bc9a..0b8577669aff 100644
--- a/third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp
+++ b/third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp
@@ -24,14 +24,14 @@ struct TestAMDFoldTrueCmpIOpPass
   void runOnOperation() override {
     DenseMap<Value, SetVector<Operation *>> assumptions =
         AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation());
-    std::shared_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
     solver->load<AMD::TritonIntegerRangeAnalysis>(assumptions);
     if (failed(solver->initializeAndRun(getOperation())))
       return signalPassFailure();

     ModuleOp mod = getOperation();
     RewritePatternSet patterns(&getContext());
-    AMD::populateFoldTrueCmpIOpPatterns(patterns, solver);
+    AMD::populateFoldTrueCmpIOpPatterns(patterns, solver.get());
     if (failed(applyPatternsGreedily(mod, std::move(patterns)))) {
       return signalPassFailure();
     }
From 74a5f63a23d8da2e58df2afd8fb3669ba012b86e Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Tue, 1 Apr 2025 19:01:13 -0400
Subject: [PATCH 7/7] fix bitPosition error

---
 .../TritonAMDGPUTransforms/ConvertToBufferOps.cpp | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
index 459514b17a62..1d24e82241ef 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
@@ -84,8 +84,18 @@ bool verifyNonNegativeExpr(Value expr, const DenseSet<Value> &assumptions,
                            std::shared_ptr<DataFlowSolver> solver) {
   LDBG("Determing if non-negative: " << expr);

-  if (!llvm::isa<BlockArgument>(expr) &&
-      succeeded(dataflow::staticallyNonNegative(*solver, expr))) {
+  auto nonNegativePred = [&solver](Value v) -> bool {
+    if (const auto *r =
+            solver->lookupState<dataflow::IntegerValueRangeLattice>(v)) {
+      if (r->getValue().isUninitialized())
+        return false;
+      if (AMD::isEmptyInitializedRange(r->getValue().getValue()))
+        return false;
+    }
+    return succeeded(dataflow::staticallyNonNegative(*solver, v));
+  };
+
+  if (!llvm::isa<BlockArgument>(expr) && nonNegativePred(expr)) {
     return true;
   }