@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
840
840
}
841
841
}
842
842
843
+ // -----
844
+
845
+ ///----------------------------------------------------------------------------------------
846
+ /// Tests for linalg.mmt4d
847
+ ///----------------------------------------------------------------------------------------
848
+
849
+ func.func @mmt4d (%A: memref <16 x16 x8 x1 xf32 >, %B: memref <16 x16 x8 x1 xf32 >, %C_in: memref <16 x16 x8 x8 xf32 >) {
850
+ linalg.mmt4d ins (%A , %B: memref <16 x16 x8 x1 xf32 >, memref <16 x16 x8 x1 xf32 >)
851
+ outs (%C_in: memref <16 x16 x8 x8 xf32 >)
852
+ return
853
+ }
854
+
855
+ // CHECK-LABEL: func.func @mmt4d(
856
+ // CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
857
+ // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
858
+ // CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
859
+ // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
860
+ // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
861
+ // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
862
+ // CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
863
+
864
+ module attributes {transform.with_named_sequence } {
865
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
866
+ %mmt4d = transform.structured.match ops {[" linalg.mmt4d" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
867
+ transform.structured.vectorize %mmt4d : !transform.any_op
868
+ transform.yield
869
+ }
870
+ }
871
+
872
+ // -----
873
+
874
+ func.func @mmt4d_scalable (%A: memref <16 x16 x8 x1 xf32 >, %B: memref <16 x16 x?x1 xf32 >, %C_in: memref <16 x16 x8 x?xf32 >) {
875
+ linalg.mmt4d ins (%A , %B: memref <16 x16 x8 x1 xf32 >, memref <16 x16 x?x1 xf32 >)
876
+ outs (%C_in: memref <16 x16 x8 x?xf32 >)
877
+ return
878
+ }
879
+ // CHECK-LABEL: func.func @mmt4d_scalable(
880
+ // CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
881
+ // CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
882
+ // CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
883
+ // CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
884
+ // CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
885
+ // CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
886
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
887
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
888
+ // CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
889
+ // CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
890
+ // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
891
+ // CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
892
+ // CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
893
+ // CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
894
+ // CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
895
+ // CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
896
+ // CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
897
+ // CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
898
+ // CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
899
+
900
+
901
+ module attributes {transform.with_named_sequence } {
902
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
903
+ %mmt4d = transform.structured.match ops {[" linalg.mmt4d" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
904
+ transform.structured.vectorize %mmt4d vector_sizes [16 , 16 , 16 , 8 , [4 ], 1 ] : !transform.any_op
905
+ transform.yield
906
+ }
907
+ }
908
+
909
+ // -----
910
+
911
+ func.func @mmt4d_scalable_with_assume (%A: memref <16 x16 x8 x1 xf32 >, %B: memref <16 x16 x?x1 xf32 >, %C_in: memref <16 x16 x8 x?xf32 >) {
912
+ linalg.mmt4d ins (%A , %B: memref <16 x16 x8 x1 xf32 >, memref <16 x16 x?x1 xf32 >)
913
+ outs (%C_in: memref <16 x16 x8 x?xf32 >)
914
+ return
915
+ }
916
+ // CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
917
+ // CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
918
+ // CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
919
+ // CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
920
+ // CHECK-NOT: mask
921
+ // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
922
+ // CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
923
+ // CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
924
+ // CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
925
+ // CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
926
+ // CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
927
+
928
+ module attributes {transform.with_named_sequence } {
929
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
930
+ %mmt4d = transform.structured.match ops {[" linalg.mmt4d" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
931
+ transform.structured.vectorize %mmt4d vector_sizes [16 , 16 , 16 , 8 , [4 ], 1 ] {assume_dynamic_dims_match_vec_sizes } : !transform.any_op
932
+ transform.yield
933
+ }
934
+ }
935
+
843
936
///----------------------------------------------------------------------------------------
844
937
/// Tests for other Ops
845
938
///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
1094
1187
}
1095
1188
}
1096
1189
1097
- // -----
1098
-
1099
- func.func @mmt4d (%A: memref <16 x16 x8 x1 xf32 >, %B: memref <16 x16 x8 x1 xf32 >, %C_in: memref <16 x16 x8 x8 xf32 >) {
1100
- linalg.mmt4d ins (%A , %B: memref <16 x16 x8 x1 xf32 >, memref <16 x16 x8 x1 xf32 >)
1101
- outs (%C_in: memref <16 x16 x8 x8 xf32 >)
1102
- return
1103
- }
1104
-
1105
- // CHECK-LABEL: func.func @mmt4d(
1106
- // CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
1107
- // CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
1108
- // CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
1109
- // CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
1110
- // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
1111
- // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
1112
- // CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
1113
-
1114
- module attributes {transform.with_named_sequence } {
1115
- transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
1116
- %mmt4d = transform.structured.match ops {[" linalg.mmt4d" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
1117
- transform.structured.vectorize %mmt4d : !transform.any_op
1118
- transform.yield
1119
- }
1120
- }
1121
1190
1122
1191
// -----
1123
1192
0 commit comments