@@ -174,3 +174,59 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
174174// The inner most unit dims can not be dropped if the strides are not ones.
175175// CHECK: func.func @non_unit_strides
176176// CHECK-NOT: memref.subview
177+
178+ // -----
179+
180+ func.func @leading_scalable_dimension_transfer_read (%dest : memref <24 x1 xf32 >) -> vector <[4 ]x1 xf32 > {
181+ %c0 = arith.constant 0 : index
182+ %pad = arith.constant 0.0 : f32
183+ %0 = vector.transfer_read %dest [%c0 , %c0 ], %pad {in_bounds = [true , true ]} : memref <24 x1 xf32 >, vector <[4 ]x1 xf32 >
184+ return %0 : vector <[4 ]x1 xf32 >
185+ }
186+ // CHECK: func.func @leading_scalable_dimension_transfer_read
187+ // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
188+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
189+ // CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : memref<24xf32, strided<[1]>>, vector<[4]xf32>
190+ // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<[4]xf32> to vector<[4]x1xf32>
191+ // CHECK: return %[[CAST]]
192+
193+ // -----
194+
195+ // Negative test: [1] (scalable 1) is _not_ a unit dimension.
196+ func.func @trailing_scalable_one_dim_transfer_read (%dest : memref <24 x1 xf32 >) -> vector <4 x[1 ]xf32 > {
197+ %c0 = arith.constant 0 : index
198+ %pad = arith.constant 0.0 : f32
199+ %0 = vector.transfer_read %dest [%c0 , %c0 ], %pad {in_bounds = [true , true ]} : memref <24 x1 xf32 >, vector <4 x[1 ]xf32 >
200+ return %0 : vector <4 x[1 ]xf32 >
201+ }
202+ // CHECK: func.func @trailing_scalable_one_dim_transfer_read
203+ // CHECK-NOT: vector.shape_cast
204+ // CHECK: vector.transfer_read {{.*}} : memref<24x1xf32>, vector<4x[1]xf32>
205+ // CHECK-NOT: vector.shape_cast
206+
207+ // -----
208+
209+ func.func @leading_scalable_dimension_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <[4 ]x1 xf32 >) {
210+ %c0 = arith.constant 0 : index
211+ vector.transfer_write %vec , %dest [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <[4 ]x1 xf32 >, memref <24 x1 xf32 >
212+ return
213+ }
214+ // CHECK: func.func @leading_scalable_dimension_transfer_write
215+ // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
216+ // CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
217+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
218+ // CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
219+ // CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
220+
221+ // -----
222+
223+ // Negative test: [1] (scalable 1) is _not_ a unit dimension.
224+ func.func @trailing_scalable_one_dim_transfer_write (%dest : memref <24 x1 xf32 >, %vec: vector <4 x[1 ]xf32 >, %index: index ) {
225+ %c0 = arith.constant 0 : index
226+ vector.transfer_write %vec , %dest [%index , %c0 ] {in_bounds = [true , true ]} : vector <4 x[1 ]xf32 >, memref <24 x1 xf32 >
227+ return
228+ }
229+ // CHECK: func.func @trailing_scalable_one_dim_transfer_write
230+ // CHECK-NOT: vector.shape_cast
231+ // CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
232+ // CHECK-NOT: vector.shape_cast
0 commit comments