Skip to content

Commit 70031c1

Browse files
[Tests] Add tests to check fixes
1 parent a85810d commit 70031c1

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

test/Triton/combine.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
6161
return %0, %1 : tensor<8xf32>, tensor<8xf32>
6262
}
6363

64+
// CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
65+
func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
66+
%false_val = arith.constant dense<0.0> : tensor<8xf32>
67+
68+
// Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
69+
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
70+
%0 = select %cond, %dummy_load, %false_val : tensor<8xf32>
71+
72+
// Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
73+
%real_load = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
74+
// CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
75+
%1 = select %cond, %real_load, %false_val : tensor<8xf32>
76+
77+
return %0, %1 : tensor<8xf32>, tensor<8xf32>
78+
}
79+
6480
// CHECK-LABEL: @test_combine_broadcast_constant_pattern
6581
func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
6682
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
@@ -92,6 +108,19 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
92108
return %x, %y, %z: tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
93109
}
94110

111+
// CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
112+
func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
113+
%other_val = arith.constant dense<0.0> : tensor<8xf32>
114+
115+
// Case: value at the "mask" position is not an "op". Load should not be canonicalized.
116+
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
117+
%x = tt.load %ptr, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
118+
// CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
119+
%y = tt.load %ptr, %mask, %other_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
120+
121+
return %x, %y: tensor<8xf32>, tensor<8xf32>
122+
}
123+
95124
// CHECK-LABEL: @test_canonicalize_masked_store_pattern
96125
func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>) {
97126
%true_mask = arith.constant dense<true> : tensor<8xi1>
@@ -105,3 +134,11 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
105134
tt.store %ptr, %val, %false_mask : tensor<8xf32>
106135
return
107136
}
137+
138+
// CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
139+
func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
140+
// Case: value at the "mask" position is not an "op". Store should not be canonicalized.
141+
// CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
142+
tt.store %ptr, %val, %mask : tensor<8xf32>
143+
return
144+
}

0 commit comments

Comments
 (0)