Skip to content

Commit efe3db2

Browse files
authored
[mlir][vector] Add tests for populateSinkVectorBroadcastPatterns (1/n) (llvm#102286)
Adds tests for scalable vectors in: * sink-vector-broadcast.mlir This test file exercises patterns grouped under `populateSinkVectorBroadcastPatterns`, which includes: * `ReorderElementwiseOpsOnBroadcast`, * `ReorderCastOpsOnBroadcast`. Right now there are only tests for the former. However, I've noticed that "vector-reduce-to-contract.mlir" contains tests for the latter, and I've left a few TODOs to group these tests back together in one file. Additionally, added some helpful `notifyMatchFailure` messages in `ReorderElementwiseOpsOnBroadcast`.
1 parent 9c70205 commit efe3db2

File tree

3 files changed

+120
-24
lines changed

3 files changed

+120
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -979,15 +979,18 @@ struct ReorderElementwiseOpsOnBroadcast final
979979
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
980980
return failure();
981981
if (!OpTrait::hasElementwiseMappableTraits(op))
982+
return rewriter.notifyMatchFailure(
983+
op, "Op doesn't have ElementwiseMappableTraits");
984+
if (op->getNumOperands() == 0)
982985
return failure();
983-
if (op->getNumOperands() == 0 ||
984-
op->getResults()[0].getType() != op->getOperand(0).getType()) {
985-
return failure();
986-
}
987-
// Avoid operations that only accept vector types, since broadcast
988-
// source might be scalar types.
986+
if (op->getResults()[0].getType() != op->getOperand(0).getType())
987+
return rewriter.notifyMatchFailure(op,
988+
"result and operand type mismatch");
989989
if (isa<vector::FMAOp>(op)) {
990-
return failure();
990+
return rewriter.notifyMatchFailure(
991+
op,
992+
"Op only accepts vector types - not supported as broadcast source "
993+
"might be a scalar");
991994
}
992995

993996
// Get the type of the lhs operand

mlir/test/Dialect/Vector/sink-vector-broadcast.mlir

Lines changed: 100 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
11
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
22

3+
//-----------------------------------------------------------------------------
4+
// [Pattern: ReorderElementwiseOpsOnBroadcast]
5+
//-----------------------------------------------------------------------------
6+
37
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
48
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
59
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
610
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
711
// CHECK: return %[[BCAST]] : vector<1x4xindex>
812

9-
func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
13+
func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
1014
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
1115
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
1216
%2 = arith.addi %0, %1 : vector<1x4xindex>
1317
return %2 : vector<1x4xindex>
1418
}
1519

20+
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_scalable(
21+
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
22+
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
23+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
24+
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
25+
26+
func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
27+
%0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
28+
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
29+
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
30+
return %2 : vector<1x[4]xindex>
31+
}
32+
1633
// -----
1734

1835
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
@@ -21,13 +38,26 @@ func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x
2138
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
2239
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
2340
// CHECK: return %[[BCAST]] : vector<1x4xindex>
24-
func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
41+
func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
2542
%0 = vector.splat %arg1 : vector<1x4xindex>
2643
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
2744
%2 = arith.addi %0, %1 : vector<1x4xindex>
2845
return %2 : vector<1x4xindex>
2946
}
3047

48+
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat_scalable(
49+
// CHECK-SAME: %[[ARG1:.*]]: index,
50+
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
51+
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
52+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
53+
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
54+
func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
55+
%0 = vector.splat %arg1 : vector<1x[4]xindex>
56+
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
57+
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
58+
return %2 : vector<1x[4]xindex>
59+
}
60+
3161
// -----
3262

3363
// CHECK-LABEL: func.func @broadcast_vector(
@@ -37,13 +67,27 @@ func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) ->
3767
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
3868
// CHECK: return %[[BCAST]] : vector<3x4xf32>
3969

40-
func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
70+
func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
4171
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
4272
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
4373
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
4474
return %2 : vector<3x4xf32>
4575
}
4676

77+
// CHECK-LABEL: func.func @broadcast_vector_scalable(
78+
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
79+
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
80+
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32>
81+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<[4]xf32> to vector<3x[4]xf32>
82+
// CHECK: return %[[BCAST]] : vector<3x[4]xf32>
83+
84+
func.func @broadcast_vector_scalable(%arg1: vector<[4]xf32>, %arg2: vector<[4]xf32>) -> vector<3x[4]xf32> {
85+
%arg1_bcast = vector.broadcast %arg1 : vector<[4]xf32> to vector<3x[4]xf32>
86+
%arg2_bcast = vector.broadcast %arg2 : vector<[4]xf32> to vector<3x[4]xf32>
87+
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x[4]xf32>
88+
return %2 : vector<3x[4]xf32>
89+
}
90+
4791
// -----
4892

4993
// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
@@ -53,13 +97,27 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
5397
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
5498
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
5599
// CHECK: return %[[ADD]] : vector<1x4xindex>
56-
func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
100+
func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
57101
%0 = vector.splat %arg1 : vector<1x4xindex>
58102
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
59103
%2 = arith.addi %0, %1 : vector<1x4xindex>
60104
return %2 : vector<1x4xindex>
61105
}
62106

107+
// CHECK-LABEL: func.func @broadcast_scalar_and_vec_scalable(
108+
// CHECK-SAME: %[[ARG1:.*]]: index,
109+
// CHECK-SAME: %[[ARG2:.*]]: vector<[4]xindex>) -> vector<1x[4]xindex> {
110+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x[4]xindex>
111+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<[4]xindex> to vector<1x[4]xindex>
112+
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x[4]xindex>
113+
// CHECK: return %[[ADD]] : vector<1x[4]xindex>
114+
func.func @broadcast_scalar_and_vec_scalable(%arg1: index, %arg2: vector<[4]xindex>) -> vector<1x[4]xindex> {
115+
%0 = vector.splat %arg1 : vector<1x[4]xindex>
116+
%1 = vector.broadcast %arg2 : vector<[4]xindex> to vector<1x[4]xindex>
117+
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
118+
return %2 : vector<1x[4]xindex>
119+
}
120+
63121
// -----
64122

65123
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
@@ -69,12 +127,25 @@ func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> ve
69127
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
70128
// CHECK: return %[[ADD]] : vector<4xi32>
71129

72-
func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
130+
func.func @broadcast_vector_and_scalar(%arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
73131
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
74132
%2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
75133
return %2 : vector<4xi32>
76134
}
77135

136+
// CHECK-LABEL: func.func @broadcast_vector_and_scalar_scalable(
137+
// CHECK-SAME: %[[ARG_0:.*]]: i32,
138+
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xi32>) -> vector<[4]xi32> {
139+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<[4]xi32>
140+
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<[4]xi32>
141+
// CHECK: return %[[ADD]] : vector<[4]xi32>
142+
143+
func.func @broadcast_vector_and_scalar_scalable(%arg1: i32, %arg2: vector<[4]xi32>) -> vector<[4]xi32> {
144+
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<[4]xi32>
145+
%2 = arith.addi %arg1_bcast, %arg2 : vector<[4]xi32>
146+
return %2 : vector<[4]xi32>
147+
}
148+
78149
// -----
79150

80151
#matmat_accesses = [
@@ -87,40 +158,52 @@ func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vec
87158
iterator_types = ["parallel", "parallel", "reduction"]
88159
}
89160

90-
// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
91-
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
92-
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
93-
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
94-
// CHECK: %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
95-
func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
161+
// CHECK-LABEL: func.func @negative_not_elementwise
162+
// CHECK-DAG: %[[F1:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
163+
// CHECK-DAG: %[[F2:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
164+
// CHECK-DAG: %[[F3:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
165+
// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[F1]], %[[F2]], %[[F3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
166+
func.func @negative_not_elementwise() -> vector<2x2xf32> {
96167
%f1 = arith.constant 1.0: f32
97168
%f2 = arith.constant 2.0: f32
98169
%f3 = arith.constant 3.0: f32
99170

100171
%A = vector.broadcast %f1 : f32 to vector<2x2xf32>
101172
%B = vector.broadcast %f2 : f32 to vector<2x2xf32>
102173
%C = vector.broadcast %f3 : f32 to vector<2x2xf32>
103-
%mm1 = vector.contract #matmat_trait %A, %B, %C
174+
%res = vector.contract #matmat_trait %A, %B, %C
104175
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
105176

106-
return %mm1 : vector<2x2xf32>
177+
return %res : vector<2x2xf32>
107178
}
108179

109-
// CHECK-LABEL: func.func @dont_sink_cmp(
180+
// -----
181+
182+
// The source and the result for arith.cmp have different types - not supported
183+
184+
// CHECK-LABEL: func.func @negative_source_and_result_mismatch
110185
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
111186
// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
112187
// CHECK: return %[[RETURN]]
113-
func.func @dont_sink_cmp(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
188+
func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
114189
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
115190
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
116191
return %1 : vector<1xi1>
117192
}
118193

119-
// CHECK-LABEL: func.func @dont_sink_fma(
194+
// -----
195+
196+
// vector.fma only supports vectors - currently it's not possible to replace this with e.g.:
197+
// %scalar_res = vector.fma %scalar_1, %scalar2
198+
// %vec_res = vector.broadcast %scalar_res
199+
//
200+
// TODO: It should be possible to support this case
201+
202+
// CHECK-LABEL: func.func @negative_op_only_supports_vectors
120203
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
121204
// CHECK: %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
122205
// CHECK: return %[[RESULT]]
123-
func.func @dont_sink_fma(%arg0 : f32) -> vector<1xf32> {
206+
func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
124207
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
125208
%1 = vector.fma %0, %0, %0 : vector<1xf32>
126209
return %1 : vector<1xf32>

mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,12 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
246246

247247

248248
//===----------------------------------------------------------------------===//
249+
// [Pattern: ReorderCastOpsOnBroadcast]
250+
//
249251
// Reorder casting ops and vector ops. The casting ops have almost identical
250252
// pattern, so only arith.extsi op is tested.
253+
//
254+
// TODO: Potential duplication with sink-vector-broadcast.mlir
251255
//===----------------------------------------------------------------------===//
252256

253257
// -----
@@ -272,6 +276,11 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
272276

273277
// -----
274278

279+
//===----------------------------------------------------------------------===//
280+
// [Pattern: ReorderElementwiseOpsOnTranspose]
281+
//
282+
// TODO: Potential duplication with sink-vector-broadcast.mlir
283+
//===----------------------------------------------------------------------===//
275284
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
276285
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
277286
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
@@ -282,6 +291,7 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
282291

283292
//===----------------------------------------------------------------------===//
284293
// Reorder elementwise ops and vector ops.
294+
// TODO: Potential duplication with sink-vector-broadcast.mlir
285295
//===----------------------------------------------------------------------===//
286296

287297
// -----

0 commit comments

Comments
 (0)