Skip to content

Commit ef1a296

Browse files
[BACKEND] llvm::dyn_cast -> llvm::dyn_cast_or_null (#689)
1 parent b0e1cba commit ef1a296

File tree

4 files changed

+55
-8
lines changed

4 files changed

+55
-8
lines changed

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
- name: Check cpp style
5757
if: ${{ matrix.runner != 'macos-latest' }}
5858
run: |
59-
sudo apt-get install -y clang-format
59+
pip install clang-format
6060
find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
6161
(echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)
6262

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
7171
mlir::Value falseValue = selectOp.getFalseValue();
7272

7373
auto *loadOpCandidate = trueValue.getDefiningOp();
74-
auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
74+
auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
7575
if (!loadOp)
7676
return mlir::failure();
7777

@@ -81,7 +81,7 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
8181

8282
auto *broadcastOpCandidate = mask.getDefiningOp();
8383
auto broadcastOp =
84-
llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
84+
llvm::dyn_cast_or_null<triton::BroadcastOp>(broadcastOpCandidate);
8585
if (!broadcastOp)
8686
return mlir::failure();
8787

@@ -106,7 +106,8 @@ struct CanonicalizeMaskedLoadPattern
106106
if (!mask)
107107
return mlir::failure();
108108

109-
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
109+
auto constantMask =
110+
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
110111
if (!constantMask)
111112
return mlir::failure();
112113

@@ -152,7 +153,8 @@ struct CanonicalizeMaskedStorePattern
152153
if (!mask)
153154
return mlir::failure();
154155

155-
auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
156+
auto constantMask =
157+
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
156158
if (!constantMask)
157159
return mlir::failure();
158160

lib/Dialect/TritonGPU/Transforms/Pipeline.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,17 @@ void LoopPipeliner::emitPrologue() {
301301
}
302302

303303
// If this is a load/async_copy, we need to update the mask
304-
if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
305-
Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
306-
: newOp->getOperand(3);
304+
if (Value mask = [&]() {
305+
if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
306+
return loadOp.mask();
307+
} else if (auto insertSliceAsyncOp =
308+
llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
309+
newOp)) {
310+
return insertSliceAsyncOp.mask();
311+
} else {
312+
return mlir::Value();
313+
}
314+
}()) {
307315
// assert(I1 or TensorOf<[I1]>);
308316
OpBuilder::InsertionGuard g(builder);
309317
// TODO: move this out of the loop

test/Triton/combine.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
6161
return %0, %1 : tensor<8xf32>, tensor<8xf32>
6262
}
6363

64+
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
65+
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
66+
%false_val = arith.constant dense<0.0> : tensor<8xf32>
67+
68+
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
69+
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
70+
%0 = select %cond, %dummy_load, %false_val : tensor<8xf32>
71+
72+
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
73+
%real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
74+
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
75+
%1 = select %cond, %real_load, %false_val : tensor<8xf32>
76+
77+
return %0, %1 : tensor<8xf32>, tensor<8xf32>
78+
}
79+
6480
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
6581
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
6682
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
@@ -92,6 +108,19 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
92108
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
93109
}
94110

111+
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
112+
func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
113+
%other_val = arith.constant dense<0.0> : tensor<8xf32>
114+
115+
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
116+
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
117+
%x = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
118+
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
119+
%y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
120+
121+
return %x, %y: tensor<8xf32>, tensor<8xf32>
122+
}
123+
95124
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
96125
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
97126
%true_mask = arith.constant dense<true> : tensor<8xi1>
@@ -105,3 +134,11 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
105134
tt.store %ptr, %val, %false_mask : tensor<8xf32>
106135
return
107136
}
137+
138+
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
139+
func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
140+
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
141+
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
142+
tt.store %ptr, %val, %mask : tensor<8xf32>
143+
return
144+
}

0 commit comments

Comments (0)