2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
@@ -56,7 +56,7 @@ jobs:
       - name: Check cpp style
         if: ${{ matrix.runner != 'macos-latest' }}
         run: |
-          sudo apt-get install -y clang-format
+          pip install clang-format
           find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i ||
           (echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1)

10 changes: 6 additions & 4 deletions lib/Dialect/Triton/Transforms/Combine.cpp
@@ -71,7 +71,7 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
     mlir::Value falseValue = selectOp.getFalseValue();
 
     auto *loadOpCandidate = trueValue.getDefiningOp();
-    auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate);
+    auto loadOp = llvm::dyn_cast_or_null<triton::LoadOp>(loadOpCandidate);
     if (!loadOp)
       return mlir::failure();

@@ -81,7 +81,7 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern {
 
     auto *broadcastOpCandidate = mask.getDefiningOp();
     auto broadcastOp =
-        llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate);
+        llvm::dyn_cast_or_null<triton::BroadcastOp>(broadcastOpCandidate);
     if (!broadcastOp)
       return mlir::failure();

@@ -106,7 +106,8 @@ struct CanonicalizeMaskedLoadPattern
     if (!mask)
       return mlir::failure();
 
-    auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
+    auto constantMask =
+        llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
     if (!constantMask)
       return mlir::failure();

@@ -152,7 +153,8 @@ struct CanonicalizeMaskedStorePattern
     if (!mask)
      return mlir::failure();
 
-    auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp());
+    auto constantMask =
+        llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
    if (!constantMask)
      return mlir::failure();

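Background on the cast change above: mlir::Value::getDefiningOp() returns nullptr when a value is a block argument (for example a function parameter) rather than the result of an operation, and llvm::dyn_cast asserts if it is handed a null pointer, whereas llvm::dyn_cast_or_null simply yields null so the pattern can bail out with failure(). A minimal standalone sketch of that distinction, using toy stand-ins (Op, LoadOp, dynCast, dynCastOrNull) rather than the real LLVM/MLIR classes:

#include <cassert>
#include <iostream>

struct Op { virtual ~Op() = default; };
struct LoadOp : Op {};

// Toy stand-in for llvm::dyn_cast: the input pointer must be non-null.
template <typename To, typename From>
To *dynCast(From *from) {
  assert(from && "dynCast called on a null pointer");
  return dynamic_cast<To *>(from);
}

// Toy stand-in for llvm::dyn_cast_or_null: a null input just yields null.
template <typename To, typename From>
To *dynCastOrNull(From *from) {
  return from ? dynamic_cast<To *>(from) : nullptr;
}

int main() {
  Op *definingOp = nullptr; // e.g. the value was a function argument
  // dynCast<LoadOp>(definingOp);  // would trip the assertion
  if (auto *load = dynCastOrNull<LoadOp>(definingOp))
    std::cout << "matched a load\n";
  else
    std::cout << "no defining op, rewrite pattern returns failure()\n";
  return 0;
}

The new *_fail_pattern tests further down exercise exactly this case by feeding the patterns values that come straight from function arguments.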
14 changes: 11 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -301,9 +301,17 @@ void LoopPipeliner::emitPrologue() {
     }
 
     // If this is a load/async_copy, we need to update the mask
-    if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) {
-      Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1)
-                                                    : newOp->getOperand(3);
+    if (Value mask = [&]() {
+          if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) {
+            return loadOp.mask();
+          } else if (auto insertSliceAsyncOp =
+                         llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>(
+                             newOp)) {
+            return insertSliceAsyncOp.mask();
+          } else {
+            return mlir::Value();
+          }
+        }()) {
       // assert(I1 or TensorOf<[I1]>);
       OpBuilder::InsertionGuard g(builder);
       // TODO: move this out of the loop
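The rewritten prologue code selects the mask through each op's named mask() accessor inside an immediately-invoked lambda, instead of hard-coding operand indices 1 and 3, and since a default-constructed mlir::Value converts to false, the enclosing if only runs when a mask operand is actually present. A small self-contained sketch of the same idiom, with toy Value and LoadOp types standing in for the real MLIR classes:

#include <iostream>
#include <string>

// Toy stand-in for mlir::Value: default-constructed means "no value" and
// converts to false, mirroring how an absent optional operand behaves.
struct Value {
  std::string name;
  explicit operator bool() const { return !name.empty(); }
};

struct LoadOp {
  Value maskOperand;
  Value mask() const { return maskOperand; } // named accessor, no magic index
};

int main() {
  LoadOp masked{Value{"%mask"}};
  LoadOp unmasked{};

  for (const LoadOp &op : {masked, unmasked}) {
    // Immediately-invoked lambda picks the operand; the if-with-initializer
    // only enters its body when a mask is actually present.
    if (Value mask = [&]() { return op.mask(); }()) {
      std::cout << "update mask " << mask.name << "\n";
    } else {
      std::cout << "no mask to update\n";
    }
  }
  return 0;
}

Compared with the old newOp->getOperand(1) / getOperand(3) indexing, the accessor-based form does not depend on operand positions and handles ops that carry no mask at all.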
37 changes: 37 additions & 0 deletions test/Triton/combine.mlir
@@ -61,6 +61,22 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
   return %0, %1 : tensor<8xf32>, tensor<8xf32>
 }
 
+// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
+func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
+  %false_val = arith.constant dense<0.0> : tensor<8xf32>
+
+  // Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
+  // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+  %0 = select %cond, %dummy_load, %false_val : tensor<8xf32>
+
+  // Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
+  %real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
+  // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+  %1 = select %cond, %real_load, %false_val : tensor<8xf32>
+
+  return %0, %1 : tensor<8xf32>, tensor<8xf32>
+}
+
 // CHECK-LABEL: @test_combine_broadcast_constant_pattern
 func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
   // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
@@ -92,6 +108,19 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
   return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
 }
 
+// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
+func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
+  %other_val = arith.constant dense<0.0> : tensor<8xf32>
+
+  // Case: value at the "mask" position is not an "op". Load should not be canonicalized.
+  // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
+  %x = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
+  // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
+  %y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
+
+  return %x, %y: tensor<8xf32>, tensor<8xf32>
+}
+
 // CHECK-LABEL: @test_canonicalize_masked_store_pattern
 func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
   %true_mask = arith.constant dense<true> : tensor<8xi1>
@@ -105,3 +134,11 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
   tt.store %ptr, %val, %false_mask : tensor<8xf32>
   return
 }
+
+// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
+func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
+  // Case: value at the "mask" position is not an "op". Store should not be canonicalized.
+  // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+  tt.store %ptr, %val, %mask : tensor<8xf32>
+  return
+}