11// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
22
3+ /// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
4+
35///----------------------------------------------------------------------------------------
4- /// vector.transfer_write -> vector.transpose + vector.transfer_write
56/// [Pattern: TransferWritePermutationLowering]
7+ ///
8+ /// IN: vector.transfer_write (_transposed_ minor identity permutation map)
9+ /// OUT: vector.transpose + vector.transfer_write (minor identity permutation map)
610///----------------------------------------------------------------------------------------
7- /// Input:
8- /// * vector.transfer_write op with a permutation that under a transpose
9- /// _would be_ a minor identity permutation map
10- /// Output:
11- /// * vector.transpose + vector.transfer_write with a permutation map which
12- /// _is_ a minor identity
13-
14- // CHECK-LABEL: func.func @xfer_write_transposing_permutation_map
11+
12+ // CHECK-LABEL: func.func @xfer_write_minor_identity_transposed
1513// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
1614// CHECK-SAME: %[[MEM:.*]]: memref<2x2x8x4xi16>
1715// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
1816// CHECK: vector.transfer_write
1917// CHECK-NOT: permutation_map
2018// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
21- func.func @xfer_write_transposing_permutation_map (
19+ func.func @xfer_write_minor_identity_transposed (
2220 %vec: vector <4 x8 xi16 >,
2321 %mem: memref <2 x2 x8 x4 xi16 >,
2422 %idx: index ) {
@@ -33,7 +31,7 @@ func.func @xfer_write_transposing_permutation_map(
3331
3432// Even with out-of-bounds accesses, it is safe to apply this pattern
3533
36- // CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
34+ // CHECK-LABEL: func.func @xfer_write_minor_identity_transposed_out_of_bounds
3735// CHECK-SAME: %[[VEC:.*]]: vector<4x8xi16>,
3836// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>,
3937// CHECK-SAME: %[[IDX:.*]]: index) {
@@ -44,7 +42,7 @@ func.func @xfer_write_transposing_permutation_map(
4442// CHECK: vector.transfer_write
4543// CHECK-NOT: permutation_map
4644// CHECK-SAME: %[[TR]], %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
47- func.func @xfer_write_transposing_permutation_map_out_of_bounds (
45+ func.func @xfer_write_minor_identity_transposed_out_of_bounds (
4846 %vec: vector <4 x8 xi16 >,
4947 %mem: memref <2 x2 x?x?xi16 >,
5048 %idx: index ) {
@@ -57,15 +55,15 @@ func.func @xfer_write_transposing_permutation_map_out_of_bounds(
5755 return
5856}
5957
60- // CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
58+ // CHECK-LABEL: func.func @xfer_write_minor_identity_transposed_with_mask_scalable
6159// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
6260// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
6361// CHECK-SAME: %[[MASK:.*]]: vector<[8]x4xi1>
6462// CHECK: %[[TR:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
6563// CHECK: vector.transfer_write
6664// CHECK-NOT: permutation_map
6765// CHECK-SAME: %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
68- func.func @xfer_write_transposing_permutation_map_with_mask_scalable (
66+ func.func @xfer_write_minor_identity_transposed_with_mask_scalable (
6967 %vec: vector <4 x[8 ]xi16 >,
7068 %mem: memref <2 x2 x?x4 xi16 >,
7169 %mask: vector <[8 ]x4 xi1 >,
@@ -82,9 +80,9 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
8280
8381// Masked version is not supported
8482
85- // CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_masked
83+ // CHECK-LABEL: func.func @xfer_write_minor_identity_transposed_map_masked
8684// CHECK-NOT: vector.transpose
87- func.func @xfer_write_transposing_permutation_map_masked (
85+ func.func @xfer_write_minor_identity_transposed_map_masked (
8886 %vec: vector <4 x8 xi16 >,
8987 %mem: memref <2 x2 x8 x4 xi16 >,
9088 %mask: vector <8 x4 xi1 >,
@@ -102,24 +100,24 @@ func.func @xfer_write_transposing_permutation_map_masked(
102100}
103101
104102///----------------------------------------------------------------------------------------
105- /// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write
106103/// [Patterns: TransferWriteNonPermutationLowering + TransferWritePermutationLowering]
104+ ///
105+ /// IN: vector.transfer_write
106+ /// (neither a minor identity nor transposed minor identity map)
107+ /// OUT 1: vector.broadcast + vector.transfer_write
108+ /// (transposed minor identity)
109+ /// OUT 2: vector.broadcast + vector.transpose + vector.transfer_write
110+ /// (minor identity)
107111///----------------------------------------------------------------------------------------
108- /// Input:
109- /// * vector.transfer_write op with a map which _is not_ a permutation of a
110- /// minor identity
111- /// Output:
112- /// * vector.broadcast + vector.transpose + vector.transfer_write with a map
113- /// which _is_ a permutation of a minor identity
114-
115- // CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map(
112+
113+ // CHECK-LABEL: func.func @xfer_write_non_minor_identity(
116114// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
117115// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
118116// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index) {
119117// CHECK: %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
120118// CHECK: %[[TR:.*]] = vector.transpose %[[BC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
121119// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[IDX_1]], %[[IDX_2]]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
122- func.func @xfer_write_non_transposing_permutation_map (
120+ func.func @xfer_write_non_minor_identity (
123121 %mem : memref <?x?xf32 >,
124122 %vec : vector <7 xf32 >,
125123 %idx_1 : index ,
@@ -134,7 +132,7 @@ func.func @xfer_write_non_transposing_permutation_map(
134132
135133// Even with out-of-bounds accesses, it is safe to apply this pattern
136134
137- // CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds (
135+ // CHECK-LABEL: func.func @xfer_write_non_minor_identity_with_mask_out_of_bounds (
138136// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
139137// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
140138// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
@@ -144,7 +142,7 @@ func.func @xfer_write_non_transposing_permutation_map(
144142// CHECK: %[[TR_MASK:.*]] = vector.transpose %[[BC_MASK]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
145143// CHECK: %[[TR_VEC:.*]] = vector.transpose %[[BC_VEC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
146144// CHECK: vector.transfer_write %[[TR_VEC]], %[[MEM]]{{\[}}%[[IDX_1]], %[[IDX_2]]], %[[TR_MASK]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
147- func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds (
145+ func.func @xfer_write_non_minor_identity_with_mask_out_of_bounds (
148146 %mem : memref <?x?xf32 >,
149147 %vec : vector <7 xf32 >,
150148 %idx_1 : index ,
@@ -159,7 +157,7 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
159157 return
160158}
161159
162- // CHECK: func.func @permutation_with_mask_xfer_write_scalable (
160+ // CHECK-LABEL: func.func @xfer_write_non_minor_identity_with_mask_scalable(
163161// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
164162// CHECK-SAME: %[[MEM:.*]]: memref<1x4x?x1xi16>,
165163// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
@@ -168,7 +166,7 @@ func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
168166// CHECK: %[[TRANSPOSE_1:.*]] = vector.transpose %[[BC_2]], [1, 2, 0] : vector<1x4x[8]xi1> to vector<4x[8]x1xi1>
169167// CHECK: %[[TRANSPOSE_2:.*]] = vector.transpose %[[BC_1]], [1, 2, 0] : vector<1x4x[8]xi16> to vector<4x[8]x1xi16>
170168// CHECK: vector.transfer_write %[[TRANSPOSE_2]], %[[MEM]]{{.*}}, %[[TRANSPOSE_1]] {in_bounds = [true, true, true]} : vector<4x[8]x1xi16>, memref<1x4x?x1xi16>
171- func.func @permutation_with_mask_xfer_write_scalable (
169+ func.func @xfer_write_non_minor_identity_with_mask_scalable (
172170 %vec: vector <4 x[8 ]xi16 >,
173171 %mem: memref <1 x4 x?x1 xi16 >,
174172 %mask: vector <4 x[8 ]xi1 >,
@@ -184,14 +182,14 @@ func.func @permutation_with_mask_xfer_write_scalable(
184182
185183// Masked version is not supported
186184
187- // CHECK-LABEL: func @masked_permutation_xfer_write_fixed_width
185+ // CHECK-LABEL: func @xfer_write_non_minor_identity_masked
188186// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
189187// CHECK-SAME: %[[VEC:.*]]: vector<16xf32>,
190188// CHECK-SAME: %[[IDX:.*]]: index,
191189// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>
192190// CHECK-NOT: vector.transpose
193191// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} vector<16xf32>, tensor<?x?xf32> } : vector<16xi1> -> tensor<?x?xf32>
194- func.func @masked_permutation_xfer_write_fixed_width (
192+ func.func @xfer_write_non_minor_identity_masked (
195193 %dest: tensor <?x?xf32 >,
196194 %vec: vector <16 xf32 >,
197195 %idx: index ,
@@ -206,14 +204,14 @@ func.func @masked_permutation_xfer_write_fixed_width(
206204 return %res : tensor <?x?xf32 >
207205}
208206
209- // CHECK-LABEL: func.func @masked_permutation_xfer_write_scalable(
207+ // CHECK-LABEL: func.func @xfer_write_non_minor_identity_masked_scalable
210208// CHECK-SAME: %[[VEC:.*]]: vector<4x[8]xi16>,
211209// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>,
212210// CHECK-SAME: %[[MASK:.*]]: vector<4x[8]xi1>
213211// CHECK-SAME: -> tensor<?x?x?x?xf32> {
214212// CHECK-NOT: vector.transpose
215213// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<4x[8]xi16>, tensor<?x?x?x?xf32> } : vector<4x[8]xi1> -> tensor<?x?x?x?xf32>
216- func.func @masked_permutation_xfer_write_scalable (
214+ func.func @xfer_write_non_minor_identity_masked_scalable (
217215 %vec: vector <4 x[8 ]xi16 >,
218216 %dest: tensor <?x?x?x?xf32 >,
219217 %mask: vector <4 x[8 ]xi1 >,
@@ -232,13 +230,13 @@ func.func @masked_permutation_xfer_write_scalable(
232230
233231// Masked version is not supported
234232
235- // CHECK-LABEL: func @masked_non_permutation_xfer_write_fixed_width
233+ // CHECK-LABEL: func @xfer_write_non_minor_identity_masked_2
236234// CHECK-SAME: %[[DEST:.*]]: tensor<?x?x?x?xf32>
237235// CHECK-SAME: %[[VEC:.*]]: vector<14x8x16xf32>
238236// CHECK-SAME: %[[DIM:.*]]: index, %[[IDX:.*]]: index) -> tensor<?x?x?x?xf32>
239237// CHECK-NOT: vector.broadcast
240238// CHECK: vector.mask %0 { vector.transfer_write %[[VEC]], %[[DEST]]{{.*}} : vector<14x8x16xf32>, tensor<?x?x?x?xf32> } : vector<14x8x16xi1> -> tensor<?x?x?x?xf32>
241- func.func @masked_non_permutation_xfer_write_fixed_width (
239+ func.func @xfer_write_non_minor_identity_masked_2 (
242240 %dest : tensor <?x?x?x?xf32 >,
243241 %vec : vector <14 x8 x16 xf32 >,
244242 %dim : index ,
@@ -256,15 +254,17 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
256254}
257255
258256///----------------------------------------------------------------------------------------
259- /// vector.transfer_read
257+ /// [Pattern: TransferOpReduceRank (for leading 0 dim) +
258+ /// TransferReadPermutationLowering (for transposed minor identity)]
259+ ///
260+ /// IN: vector.transfer_read
261+ /// (_transposed_ minor identity permutation map, with 0 or more broadcast dims)
262+ /// OUT: vector.transfer_read + vector.transpose
263+ /// (minor identity permutation map with 0 or more leading broadcast dims)
260264///----------------------------------------------------------------------------------------
261- /// Input:
262- /// * vector.transfer_read op with a permutation map
263- /// Output:
264- /// * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
265- /// vector.transpose op
265+ /// TODO: Inner broadcast dim - see also the block at the bottom of this file
266266
267- // CHECK-LABEL: func.func @permutation_with_mask_xfer_read_fixed_width(
267+ // CHECK-LABEL: func.func @xfer_read_minor_identity_tranposed_with_mask
268268// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
269269// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x4x2xf32> {
270270// CHECK: %[[PASS_THROUGH:.*]] = arith.constant 0.000000e+00 : f32
@@ -273,7 +273,7 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
273273// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x4xf32> to vector<8x2x4xf32>
274274// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x4xf32> to vector<8x4x2xf32>
275275// CHECK: return %[[TRANSPOSE]] : vector<8x4x2xf32>
276- func.func @permutation_with_mask_xfer_read_fixed_width (
276+ func.func @xfer_read_minor_identity_tranposed_with_mask (
277277 %mem: memref <?x?xf32 >,
278278 %dim_1: index ,
279279 %dim_2: index ,
@@ -290,7 +290,7 @@ func.func @permutation_with_mask_xfer_read_fixed_width(
290290 return %res : vector <8 x4 x2 xf32 >
291291}
292292
293- // CHECK-LABEL: func.func @permutation_with_mask_xfer_read_scalable (
293+ // CHECK-LABEL: func.func @xfer_read_minor_identity_tranposed_with_mask_scalable (
294294// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
295295// CHECK-SAME: %[[DIM_1:.*]]: index, %[[DIM_2:.*]]: index, %[[IDX:.*]]: index) -> vector<8x[4]x2xf32> {
296296// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
@@ -299,7 +299,7 @@ func.func @permutation_with_mask_xfer_read_fixed_width(
299299// CHECK: %[[BCAST:.*]] = vector.broadcast %[[T_READ]] : vector<2x[4]xf32> to vector<8x2x[4]xf32>
300300// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 2, 1] : vector<8x2x[4]xf32> to vector<8x[4]x2xf32>
301301// CHECK: return %[[TRANSPOSE]] : vector<8x[4]x2xf32>
302- func.func @permutation_with_mask_xfer_read_scalable (
302+ func.func @xfer_read_minor_identity_tranposed_with_mask_scalable (
303303 %mem: memref <?x?xf32 >,
304304 %dim_1: index ,
305305 %dim_2: index ,
@@ -318,12 +318,12 @@ func.func @permutation_with_mask_xfer_read_scalable(
318318
319319// Masked version is not supported
320320
321- // CHECK-LABEL: func @masked_permutation_xfer_read_fixed_width
321+ // CHECK-LABEL: func @xfer_read_minor_identity_transposed_masked(
322322// CHECK-SAME: %[[DEST:.*]]: tensor<?x1xf32>,
323323// CHECK-SAME: %[[MASK:.*]]: vector<4x1xi1>
324324// CHECK-NOT: vector.transpose
325325// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}}: tensor<?x1xf32>, vector<1x4x4xf32> } : vector<4x1xi1> -> vector<1x4x4xf32>
326- func.func @masked_permutation_xfer_read_fixed_width (
326+ func.func @xfer_read_minor_identity_transposed_masked (
327327 %dest: tensor <?x1 xf32 >,
328328 %mask : vector <4 x1 xi1 >,
329329 %idx: index ) {
@@ -339,12 +339,12 @@ func.func @masked_permutation_xfer_read_fixed_width(
339339 return
340340}
341341
342- // CHECK-LABEL: func.func @masked_permutation_xfer_read_scalable (
342+ // CHECK-LABEL: func.func @xfer_read_minor_identity_transposed_masked_scalable (
343343// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
344344// CHECK-SAME: %[[MASK:.*]]: vector<2x[4]xi1>
345345// CHECK-NOT: vector.transpose
346346// CHECK: %[[T_READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[DEST]]{{.*}} : tensor<?x?xf32>, vector<8x[4]x2xf32> } : vector<2x[4]xi1> -> vector<8x[4]x2xf32>
347- func.func @masked_permutation_xfer_read_scalable (
347+ func.func @xfer_read_minor_identity_transposed_masked_scalable (
348348 %dest: tensor <?x?xf32 >,
349349 %mask : vector <2 x[4 ]xi1 >,
350350 %idx: index ) -> vector <8 x[4 ]x2 xf32 > {
@@ -361,31 +361,17 @@ func.func @masked_permutation_xfer_read_scalable(
361361 return %res : vector <8 x[4 ]x2 xf32 >
362362}
363363
364- module attributes {transform.with_named_sequence } {
365- transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
366- %f = transform.structured.match ops {[" func.func" ]} in %module_op
367- : (!transform.any_op ) -> !transform.any_op
368- transform.apply_patterns to %f {
369- transform.apply_patterns.vector.transfer_permutation_patterns
370- } : !transform.any_op
371- transform.yield
372- }
373- }
374-
375- // -----
376-
377364///----------------------------------------------------------------------------------------
378365/// vector.transfer_read
379366///----------------------------------------------------------------------------------------
380367/// TODO: Review and categorize
381368
382- // CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
383- // CHECK: func.func @transfer_read_reduce_rank_scalable(
369+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
384370// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
385- // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
371+ // CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
386372// CHECK: %[[BC:.*]] = vector.broadcast %[[T_READ]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
387373// CHECK: return %[[BC]] : vector<8x[4]x2x3xf32>
388- func.func @transfer_read_reduce_rank_scalable (
374+ func.func @xfer_read_minor_identitiy_bcast_dims_scalable (
389375 %mem: memref <?x?x?x?xf32 >, %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
390376
391377 %pad = arith.constant 0.000000e+00 : f32
@@ -400,13 +386,13 @@ func.func @transfer_read_reduce_rank_scalable(
400386
401387// Masked version is not supported
402388
403- // CHECK-LABEL: func.func @masked_transfer_read_reduce_rank(
389+ // CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_masked
404390// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>,
405391// CHECK-SAME: %[[DIM:.*]]: index,
406392// CHECK-SAME: %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
407393// CHECK-NOT: vector.broadcast
408394// CHECK: %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %[[MEM]]{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
409- func.func @masked_transfer_read_reduce_rank (
395+ func.func @xfer_read_minor_identitiy_bcast_dims_masked (
410396 %mem: memref <?x?x?x?xf32 >,
411397 %dim: index ,
412398 %idx: index ) -> vector <8 x[4 ]x2 x3 xf32 > {
0 commit comments