11// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22
3+ // TODO: Review the usage of `in_bounds` and remove it where it does not
4+ // affect the generated output.
5+
36/// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
47
58///----------------------------------------------------------------------------------------
@@ -106,8 +109,8 @@ func.func @xfer_write_minor_identity_transposed_map_masked(
106109/// (neither a minor identity nor transposed minor identity map)
107110/// OUT 1: vector.broadcast + vector.transfer_write
108111/// (transposed minor identity)
109- /// OUT 2: vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write
110- /// (minor identity)
112+ /// OUT 2: vector.transfer_write -> vector.broadcast + vector.transpose
113+ /// + vector.transfer_write (minor identity)
111114///----------------------------------------------------------------------------------------
112115
113116// CHECK-LABEL: func.func @xfer_write_non_minor_identity(
@@ -233,16 +236,16 @@ func.func @xfer_write_non_minor_identity_masked_scalable(
233236// CHECK-LABEL: func @xfer_write_non_minor_identity_masked_2
234237// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>
235238// CHECK-SAME: %[[VEC:.*]]: vector<14x8x16xf32>
236- // CHECK-SAME: %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
239+ // CHECK-SAME: %[[MASK:.*]]: vector<14x8x16xi1>
240+ // CHECK-SAME: %[[DIM:.*]]: index
237241// CHECK-NOT: vector.broadcast
238- // CHECK: vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
242+ // CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
239243func.func @xfer_write_non_minor_identity_masked_2 (
240244 %dest : tensor <?x?x?x?xf32 >,
241245 %vec : vector <14 x8 x16 xf32 >,
242- %dim : index ,
246+ %mask: vector < 14 x 8 x 16 x i1 > ,
243247 %idx: index ) -> tensor <?x?x?x?xf32 > {
244248
245- %mask = vector.create_mask %dim , %dim , %dim : vector <14 x8 x16 xi1 >
246249 %res = vector.mask %mask {
247250 vector.transfer_write %vec , %dest [%idx , %idx , %idx , %idx ] {
248251 in_bounds = [false , false , true ],
@@ -259,29 +262,27 @@ func.func @xfer_write_non_minor_identity_masked_2(
259262///
260263/// IN: vector.transfer_read
261264/// (_transposed_ minor identity permutation map, with 0 or more broadcast dims)
262- /// OUT: vector.transpose + vector.transfer_write
265+ /// OUT: vector.transfer_read + vector.broadcast + vector.transpose
263266/// (minor identity permutation map with 0 or more leading broadcast dims)
264267///----------------------------------------------------------------------------------------
265268/// TODO: Inner broadcast dim - see also the block at the bottom of this file
266269
267- // CHECK-LABEL: func.func @xfer_read_minor_identity_tranposed_with_mask
270+ // CHECK-LABEL: func.func @xfer_read_minor_identity_transposed_with_mask
268271// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
269- // CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
272+ // CHECK-SAME: %[[MASK:.*]]: vector<2x4xi1>
273+ // CHECK-SAME: %[[IDX:.*]]: index
270274// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
271- // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x4xi1>
272275// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PASS_THROUGH]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x4xf32>
273276// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
274277// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
275278// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
276- func.func @xfer_read_minor_identity_tranposed_with_mask (
279+ func.func @xfer_read_minor_identity_transposed_with_mask (
277280 %mem: memref <?x?xf32 >,
278- %dim_1: index ,
279- %dim_2: index ,
281+ %mask: vector <2 x4 xi1 >,
280282 %idx: index ) -> (vector <8 x4 x2 xf32 >) {
281283
282284 %pad = arith.constant 0.000000e+00 : f32
283285
284- %mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x4 xi1 >
285286 %res = vector.transfer_read %mem [%idx , %idx ], %pad , %mask {
286287 in_bounds = [true , true , true ],
287288 permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
@@ -290,24 +291,22 @@ func.func @xfer_read_minor_identity_tranposed_with_mask(
290291 return %res : vector <8 x4 x2 xf32 >
291292}
292293
293- // CHECK-LABEL: func.func @xfer_read_minor_identity_tranposed_with_mask_scalable (
294+ // CHECK-LABEL: func.func @xfer_read_minor_identity_transposed_with_mask_scalable (
294295// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
295- // CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
296+ // CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>
297+ // CHECK-SAME: %[[IDX:.*]]: index
296298// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
297- // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_2]], %[[DIM_1]] : vector<2x[4]xi1>
298299// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[IDX]], %[[IDX]]], %[[PAD]], %[[MASK]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<2x[4]xf32>
299300// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
300301// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
301302// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
302- func.func @xfer_read_minor_identity_tranposed_with_mask_scalable (
303+ func.func @xfer_read_minor_identity_transposed_with_mask_scalable (
303304 %mem: memref <?x?xf32 >,
304- %dim_1: index ,
305- %dim_2: index ,
305+ %mask: vector <2 x[4 ]xi1 >,
306306 %idx: index ) -> (vector <8 x[4 ]x2 xf32 >) {
307307
308308 %pad = arith.constant 0.000000e+00 : f32
309309
310- %mask = vector.create_mask %dim_2 , %dim_1 : vector <2 x[4 ]xi1 >
311310 %res = vector.transfer_read %mem [%idx , %idx ], %pad , %mask {
312311 in_bounds = [true , true , true ],
313312 permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
@@ -319,24 +318,26 @@ func.func @xfer_read_minor_identity_tranposed_with_mask_scalable(
319318// The masked version (transfer_read wrapped in vector.mask) is not supported: no vector.transpose is generated.
320319
321320// CHECK-LABEL: func @xfer_read_minor_identity_transposed_masked(
322- // CHECK-SAME: %[[DEST:.*]]: tensor<?x1xf32>,
323- // CHECK-SAME: %[[MASK:.*]]: vector<4x1xi1>
321+ // CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
322+ // CHECK-SAME: %[[MASK:.*]]: vector<2x4xi1>
323+ // CHECK-SAME: %[[IDX:.*]]: index
324324// CHECK-NOT: vector.transpose
325- // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32 >, vector<1x4x4xf32 > } : vector<4x1xi1 > -> vector<1x4x4xf32 >
325+ // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x?xf32 >, vector<8x4x2xf32 > } : vector<2x4xi1 > -> vector<8x4x2xf32 >
326326func.func @xfer_read_minor_identity_transposed_masked (
327- %dest: tensor <?x 1 x f32 >,
328- %mask : vector <4 x 1 x i1 >,
329- %idx: index ) {
327+ %dest: tensor <?x?x f32 >,
328+ %mask: vector <2 x 4 x i1 >,
329+ %idx: index ) -> ( vector < 8 x 4 x 2 x f32 >) {
330330
331331 %pad = arith.constant 0.000000e+00 : f32
332- %3 = vector.mask %mask {
332+
333+ %res = vector.mask %mask {
333334 vector.transfer_read %dest [%idx , %idx ], %pad {
334- permutation_map = affine_map <(d0 , d1 ) -> (d1 , 0 , d0 )>
335- } : tensor <?x1 xf32 >, vector <1 x4 x4 xf32 >
336- } : vector <4 x1 xi1 > -> vector <1 x4 x4 xf32 >
335+ in_bounds = [true , true , true ],
336+ permutation_map = affine_map <(d0 , d1 ) -> (0 , d1 , d0 )>
337+ } : tensor <?x?xf32 >, vector <8 x4 x2 xf32 >
338+ } : vector <2 x4 xi1 > -> vector <8 x4 x2 xf32 >
337339
338- " test.some_use" (%3 ) : (vector <1 x4 x4 xf32 >) -> ()
339- return
340+ return %res : vector <8 x4 x2 xf32 >
340341}
341342
342343// CHECK-LABEL: func.func @xfer_read_minor_identity_transposed_masked_scalable(
@@ -346,7 +347,7 @@ func.func @xfer_read_minor_identity_transposed_masked(
346347// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
347348func.func @xfer_read_minor_identity_transposed_masked_scalable (
348349 %dest: tensor <?x?xf32 >,
349- %mask : vector <2 x[4 ]xi1 >,
350+ %mask: vector <2 x[4 ]xi1 >,
350351 %idx: index ) -> vector <8 x[4 ]x2 xf32 > {
351352
352353 %pad = arith.constant 0.000000e+00 : f32
@@ -388,17 +389,16 @@ func.func @xfer_read_minor_identitiy_bcast_dims_scalable(
388389
389390// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
390391// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
391- // CHECK-SAME: %[[DIM :.*]]: index,
392+ // CHECK-SAME: %[[MASK :.*]]: vector<[4]x3xi1>
392393// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
393394// CHECK-NOT: vector.broadcast
394- // CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
395+ // CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
395396func.func @xfer_read_minor_identitiy_bcast_dims_masked (
396397 %mem: memref <?x?x?x?xf32 >,
397- %dim: index ,
398+ %mask: vector <[ 4 ]x 3 x i1 > ,
398399 %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
399400
400401 %pad = arith.constant 0.000000e+00 : f32
401- %mask = vector.create_mask %dim , %dim: vector <[4 ]x3 xi1 >
402402
403403 %res = vector.mask %mask {
404404 vector.transfer_read %mem [%idx , %idx , %idx , %idx ], %pad {
0 commit comments