@@ -61,6 +61,22 @@ func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %con
6161 return %0 , %1 : tensor <8 xf32 >, tensor <8 xf32 >
6262}
6363
64+ // CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
65+ func @test_combine_select_masked_load_fail_pattern (%ptr: tensor <8 x!tt.ptr <f32 >>, %dummy_load: tensor <8 xf32 >, %dummy_broadcast: tensor <8 xi1 >, %cond: i1 ) -> (tensor <8 xf32 >, tensor <8 xf32 >) {
66+ %false_val = arith.constant dense <0.0 > : tensor <8 xf32 >
67+
68+ // Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
69+ // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
70+ %0 = select %cond , %dummy_load , %false_val : tensor <8 xf32 >
71+
72+ // Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
73+ %real_load = tt.load %ptr , %dummy_broadcast , %false_val {cache = 1 : i32 , evict = 1 : i32 , isVolatile = false } : tensor <8 xf32 >
74+ // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
75+ %1 = select %cond , %real_load , %false_val : tensor <8 xf32 >
76+
77+ return %0 , %1 : tensor <8 xf32 >, tensor <8 xf32 >
78+ }
79+
6480// CHECK-LABEL: @test_combine_broadcast_constant_pattern
6581func @test_combine_broadcast_constant_pattern (%cst : f32 ) -> tensor <8 x2 xf32 > {
6682 // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
@@ -92,6 +108,19 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr<f32>>) -> (te
92108 return %x , %y , %z: tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >
93109}
94110
111+ // CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
112+ func @test_canonicalize_masked_load_fail_pattern (%ptr: tensor <8 x!tt.ptr <f32 >>, %mask: tensor <8 xi1 >) -> (tensor <8 xf32 >, tensor <8 xf32 >) {
113+ %other_val = arith.constant dense <0.0 > : tensor <8 xf32 >
114+
115+ // Case: value at the "mask" position is not an "op". Load should not be canonicalized.
116+ // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
117+ %x = tt.load %ptr , %mask {cache = 1 : i32 , evict = 1 : i32 , isVolatile = false } : tensor <8 xf32 >
118+ // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
119+ %y = tt.load %ptr , %mask , %other_val {cache = 1 : i32 , evict = 1 : i32 , isVolatile = false } : tensor <8 xf32 >
120+
121+ return %x , %y: tensor <8 xf32 >, tensor <8 xf32 >
122+ }
123+
95124// CHECK-LABEL: @test_canonicalize_masked_store_pattern
96125func @test_canonicalize_masked_store_pattern (%ptr: tensor <8 x!tt.ptr <f32 >>, %val: tensor <8 xf32 >) {
97126 %true_mask = arith.constant dense <true > : tensor <8 xi1 >
@@ -105,3 +134,11 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr<f32>>, %val:
105134 tt.store %ptr , %val , %false_mask : tensor <8 xf32 >
106135 return
107136}
137+
138+ // CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
139+ func @test_canonicalize_masked_store_fail_pattern (%ptr: tensor <8 x!tt.ptr <f32 >>, %val: tensor <8 xf32 >, %mask: tensor <8 xi1 >) {
140+ // Case: value at the "mask" position is not an "op". Store should not be canonicalized.
141+ // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
142+ tt.store %ptr , %val , %mask : tensor <8 xf32 >
143+ return
144+ }
0 commit comments