diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index ea7d3f19169e..5c45903a8f51 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -56,7 +56,8 @@ jobs: - name: Check cpp style if: ${{ matrix.runner != 'macos-latest' }} run: | - sudo apt-get install -y clang-format + # sudo apt-get install -y clang-format + pip install clang-format find . -regex '.*\.\(cpp\|hpp\|h\|cc\)' -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file --dry-run -Werror -i || (echo '::error title=Style issues:: Please run `find . -regex ".*\.\(cpp\|hpp\|h\|cc\)" -not -path "./python/build/*" -not -path "./include/triton/external/*" -print0 | xargs -0 -n1 clang-format -style=file -i`' ; exit 1) diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 249a3c075175..1607ed47f907 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -71,6 +71,8 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern { mlir::Value falseValue = selectOp.getFalseValue(); auto *loadOpCandidate = trueValue.getDefiningOp(); + if (!loadOpCandidate) + return mlir::failure(); auto loadOp = llvm::dyn_cast<triton::LoadOp>(loadOpCandidate); if (!loadOp) return mlir::failure(); @@ -80,6 +82,8 @@ return mlir::failure(); auto *broadcastOpCandidate = mask.getDefiningOp(); + if (!broadcastOpCandidate) + return mlir::failure(); auto broadcastOp = llvm::dyn_cast<triton::BroadcastOp>(broadcastOpCandidate); if (!broadcastOp) return mlir::failure(); @@ -106,7 +110,11 @@ struct CanonicalizeMaskedLoadPattern if (!mask) return mlir::failure(); - auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp()); + auto *constantMaskCandidate = mask.getDefiningOp(); + if (!constantMaskCandidate) + return mlir::failure(); + auto constantMask = + 
llvm::dyn_cast<arith::ConstantOp>(constantMaskCandidate); if (!constantMask) return mlir::failure(); @@ -152,7 +160,11 @@ struct CanonicalizeMaskedStorePattern if (!mask) return mlir::failure(); - auto constantMask = llvm::dyn_cast<arith::ConstantOp>(mask.getDefiningOp()); + auto *constantMaskCandidate = mask.getDefiningOp(); + if (!constantMaskCandidate) + return mlir::failure(); + auto constantMask = + llvm::dyn_cast<arith::ConstantOp>(constantMaskCandidate); if (!constantMask) return mlir::failure(); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 0203ddd2be47..c22b7d5f46f5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -301,9 +301,17 @@ void LoopPipeliner::emitPrologue() { } // If this is a load/async_copy, we need to update the mask - if (llvm::isa<triton::LoadOp, triton::gpu::InsertSliceAsyncOp>(newOp)) { - Value mask = llvm::isa<triton::LoadOp>(newOp) ? newOp->getOperand(1) - : newOp->getOperand(3); + if (Value mask = [&]() { + if (auto loadOp = llvm::dyn_cast<triton::LoadOp>(newOp)) { + return loadOp.mask(); + } else if (auto insertSliceAsyncOp = + llvm::dyn_cast<triton::gpu::InsertSliceAsyncOp>( + newOp)) { + return insertSliceAsyncOp.mask(); + } else { + return mlir::Value(); + } + }()) { // assert(I1 or TensorOf<[I1]>); OpBuilder::InsertionGuard g(builder); // TODO: move this out of the loop diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 503cc9a26f43..7536e08f7779 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -61,6 +61,22 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con return %0, %1 : tensor<8xf32>, tensor<8xf32> } +// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern +func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) { + %false_val = arith.constant dense<0.0> : tensor<8xf32> + + // Case 1: value at the "load" position is not an "op". Select should not be canonicalized. 
+ // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> + %0 = select %cond, %dummy_load, %false_val : tensor<8xf32> + + // Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized. + %real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> + %1 = select %cond, %real_load, %false_val : tensor<8xf32> + + return %0, %1 : tensor<8xf32>, tensor<8xf32> +} + // CHECK-LABEL: @test_combine_broadcast_constant_pattern func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> @@ -92,6 +108,19 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32> } +// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern +func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) { + %other_val = arith.constant dense<0.0> : tensor<8xf32> + + // Case: value at the "mask" position is not an "op". Load should not be canonicalized. 
+ // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %x = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + %y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32> + + return %x, %y: tensor<8xf32>, tensor<8xf32> +} + // CHECK-LABEL: @test_canonicalize_masked_store_pattern func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) { %true_mask = arith.constant dense<true> : tensor<8xi1> @@ -105,3 +134,11 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tt.store %ptr, %val, %false_mask : tensor<8xf32> return } + +// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern +func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) { + // Case: value at the "mask" position is not an "op". Store should not be canonicalized. + // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32> + tt.store %ptr, %val, %mask : tensor<8xf32> + return +}