Commit c8f0fce

add fold-true-cmpi pattern/test pass
1 parent 58402e8 commit c8f0fce

7 files changed, +188 -2 lines changed

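What this change appears to do (a summary inferred from the diff below; it is not part of the original commit message): a new FoldTrueCmpIOp rewrite pattern asks the AMD integer range analysis, seeded with llvm.intr.assume facts, whether an arith.cmpi result is statically known to be true and, if so, replaces it with a constant. The commit also adds a -test-tritonamdgpu-fold-true-cmpi test pass and a lit test exercising the pattern. A rough, hypothetical sketch of the intended effect on IR (example invented for this note, not taken from the commit):

// Hypothetical input: the assume constrains %n to be >= 1, so range
// analysis can prove the weaker comparison %cond (%n >= 0) is always true.
func.func @example(%n: index) -> i1 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %assumed = arith.cmpi sge, %n, %c1 : index
  llvm.intr.assume %assumed : i1
  %cond = arith.cmpi sge, %n, %c0 : index
  func.return %cond : i1
}

// After -test-tritonamdgpu-fold-true-cmpi -canonicalize, %cond should fold
// to a constant:
//   %true = arith.constant true
//   func.return %true : i1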

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@ void registerTestAlignmentPass();
 void registerTestAllocationPass();
 void registerTestMembarPass();
 void registerTestTritonAMDGPURangeAnalysis();
+void registerTestTritonAMDGPUFoldTrueCmpIOp();
 } // namespace test
 } // namespace mlir

@@ -47,6 +48,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestAllocationPass();
   mlir::test::registerTestMembarPass();
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
+  mlir::test::registerTestTritonAMDGPUFoldTrueCmpIOp();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::gpu::registerAllocateSharedMemoryPass();
   mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -test-tritonamdgpu-fold-true-cmpi -canonicalize | FileCheck %s
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  tt.func @assume_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>) -> tensor<128x128xf32, #mma> {
+    %c-1 = arith.constant -1 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %true = arith.constant true
+    %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked>
+    %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1>
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked>
+    %0 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #blocked1>
+    %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1>
+    %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1>
+    %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
+    %5 = tt.splat %arg4 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
+    %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
+    %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked>
+    %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
+    %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
+    %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
+    %12 = arith.cmpi slt, %arg0, %arg1 : index
+    %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1>
+    %14 = tt.load %4, %13 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr<f16>, #blocked1>
+    %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked>
+    %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
+    %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
+    ttg.local_store %14, %17 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
+    %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
+    ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
+    %19 = arith.subi %arg1, %arg2 : index
+    %20:6 = scf.for %arg5 = %arg0 to %19 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) {
+      %33 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
+      %34 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
+      llvm.intr.assume %true : i1
+      %35 = tt.load %33 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr<f16>, #blocked1>
+      %36 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+      %37 = tt.load %34 : tensor<32x128x!tt.ptr<f16>, #blocked>
+      %38 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %39 = arith.mulf %38, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %40 = tt.dot %36, %39, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+      %41 = arith.addi %arg9, %c1_i32 : i32
+      %42 = arith.cmpi slt, %41, %c1_i32 : i32
+      %43 = arith.select %42, %41, %c0_i32 : i32
+      %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
+      ttg.local_store %35, %44 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
+      %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
+      ttg.local_store %37, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
+      scf.yield %33, %34, %40, %43, %44, %45 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
+    }
+    %21 = arith.cmpi slt, %arg2, %c0 : index
+    %22 = arith.select %21, %c1, %c-1 : index
+    %23 = arith.subi %arg1, %arg0 : index
+    %24 = arith.addi %23, %arg2 : index
+    %25 = arith.addi %24, %22 : index
+    %26 = arith.divsi %25, %arg2 : index
+    %28 = ttg.local_load %20#4 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %29 = ttg.local_load %20#5 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %30 = arith.mulf %29, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %27 = arith.cmpi sge, %26, %c1 : index
+    llvm.intr.assume %27 : i1
+    %31 = scf.if %27 -> (tensor<128x128xf32, #mma>) {
+      %33 = tt.dot %28, %30, %20#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+      scf.yield %33 : tensor<128x128xf32, #mma>
+    } else {
+      scf.yield %20#2 : tensor<128x128xf32, #mma>
+    }
+    %32 = arith.select %27, %31, %20#2 : tensor<128x128xf32, #mma>
+    ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
+    ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
+    tt.return %32 : tensor<128x128xf32, #mma>
+  }
+}
+
+// CHECK: #[[$ATTR_2:.+]] = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}>
+// CHECK: #[[$ATTR_3:.+]] = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+// CHECK: #[[$ATTR_4:.+]] = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
+// CHECK: #[[$ATTR_5:.+]] = #ttg.shared_memory
+
+// CHECK-LABEL: tt.func @assume_matmul(
+// CHECK: %[[VAL_7:.*]] = arith.constant true
+// CHECK: %[[VAL_8:.*]] = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>>
+// CHECK: %[[VAL_23:.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable>
+// CHECK: %[[VAL_24:.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable>
+// CHECK: %[[VAL_33:.*]]:6 = scf.for
+// CHECK: scf.yield
+// CHECK: }
+// CHECK-NEXT: %[[VAL_54:.*]] = ttg.local_load %[[VAL_55:.*]]#4 : !ttg.memdesc<128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$ATTR_2]], kWidth = 2}>>
+// CHECK-NEXT: %[[VAL_56:.*]] = ttg.local_load %[[VAL_55]]#5 : !ttg.memdesc<32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>>
+// CHECK-NEXT: %[[VAL_57:.*]] = arith.mulf %[[VAL_56]], %[[VAL_8]] : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>>
+// CHECK-NEXT: llvm.intr.assume %[[VAL_7]] : i1
+// CHECK-NEXT: %[[VAL_58:.*]] = tt.dot %[[VAL_54]], %[[VAL_57]], %[[VAL_55]]#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$ATTR_2]], kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$ATTR_2]], kWidth = 2}>> -> tensor<128x128xf32, #[[$ATTR_2]]>
+// CHECK-NEXT: ttg.local_dealloc %[[VAL_23]] : !ttg.memdesc<1x128x32xf16, #[[$ATTR_3]], #[[$ATTR_5]], mutable>
+// CHECK-NEXT: ttg.local_dealloc %[[VAL_24]] : !ttg.memdesc<1x32x128xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable>
+// CHECK-NEXT: tt.return %[[VAL_58]] : tensor<128x128xf32, #[[$ATTR_2]]>
+// CHECK-NEXT: }

third_party/amd/include/Analysis/RangeAnalysis.h

Lines changed: 2 additions & 1 deletion
@@ -122,7 +122,8 @@ collectRanges(const DataFlowSolver &solver, ValueRange values);

 bool cmpIIsStaticallyTrue(const DataFlowSolver &solver, arith::CmpIOp cmpOp);

-bool isEmptyInitializedRange(ConstantIntRanges rv);
+void populateFoldTrueCmpIOpPatterns(
+    RewritePatternSet &patterns, std::shared_ptr<DataFlowSolver> solver);

 } // namespace mlir::triton::AMD

third_party/amd/lib/Analysis/RangeAnalysis.cpp

Lines changed: 27 additions & 0 deletions
@@ -474,4 +474,31 @@ TritonIntegerRangeAnalysis::collectAssumptions(Operation *rootOp,
   return assumptions;
 }

+struct FoldTrueCmpIOp : OpRewritePattern<arith::CmpIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  FoldTrueCmpIOp(MLIRContext *context, std::shared_ptr<DataFlowSolver> solver)
+      : OpRewritePattern(context), solver(std::move(solver)) {};
+
+  LogicalResult matchAndRewrite(arith::CmpIOp cmpOp,
+                                PatternRewriter &rewriter) const override {
+    if (cmpIIsStaticallyTrue(*solver, cmpOp)) {
+      if (failed(mlir::dataflow::maybeReplaceWithConstant(*solver, rewriter,
+                                                          cmpOp.getResult()))) {
+        LDBG("failed to replace with constant op: " << cmpOp);
+      }
+    } else {
+      return failure();
+    }
+    return success();
+  }
+
+  std::shared_ptr<DataFlowSolver> solver;
+};
+
+void populateFoldTrueCmpIOpPatterns(RewritePatternSet &patterns,
+                                    std::shared_ptr<DataFlowSolver> solver) {
+  patterns.add<FoldTrueCmpIOp>(patterns.getContext(), std::move(solver));
+}
+
 } // namespace mlir::triton::AMD

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 #include "TritonAMDGPUTransforms/Passes.h"
-#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
+#include "third_party/amd/include/Analysis/RangeAnalysis.h"
 #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h"
 #include "triton/Analysis/AxisInfo.h"

third_party/amd/test/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 add_mlir_library(TritonAMDGPUTestAnalysis
   TestAMDRangeAnalysis.cpp
+  TestFoldTrueCmpIOp.cpp

   DEPENDS
   TritonTableGen

third_party/amd/test/lib/Analysis/TestFoldTrueCmpIOp.cpp

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "third_party/amd/include/Analysis/RangeAnalysis.h"
+#include "triton/Analysis/Utility.h"
+
+using namespace mlir;
+using namespace mlir::triton;
+
+namespace {
+
+struct TestAMDFoldTrueCmpIOpPass
+    : PassWrapper<TestAMDFoldTrueCmpIOpPass, OperationPass<ModuleOp>> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAMDFoldTrueCmpIOpPass)
+
+  StringRef getArgument() const final {
+    return "test-tritonamdgpu-fold-true-cmpi";
+  }
+  StringRef getDescription() const final {
+    return "print the result of the tritonamdgpu-fold-true-cmpi pass";
+  }
+
+  void runOnOperation() override {
+    DenseMap<Value, SetVector<Operation *>> assumptions =
+        AMD::TritonIntegerRangeAnalysis::collectAssumptions(getOperation());
+    std::shared_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    solver->load<AMD::TritonIntegerRangeAnalysis>(assumptions);
+    if (failed(solver->initializeAndRun(getOperation())))
+      return signalPassFailure();
+
+    ModuleOp mod = getOperation();
+    RewritePatternSet patterns(&getContext());
+    AMD::populateFoldTrueCmpIOpPatterns(patterns, solver);
+    if (failed(applyPatternsGreedily(mod, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+namespace mlir::test {
+void registerTestTritonAMDGPUFoldTrueCmpIOp() {
+  PassRegistration<TestAMDFoldTrueCmpIOpPass>();
+}
+} // namespace mlir::test
