
Commit 34ec4cf

Author: git apple-llvm automerger
Commit message: Merge commit '3b11aaaf94fe' from llvm.org/main into next
Parents: b2cedcd + 3b11aaa

5 files changed: +143 additions, -44 deletions


mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Lines changed: 5 additions & 6 deletions

@@ -2431,12 +2431,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];

   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                     $static_vector_sizes,
-                   OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                   DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                     $scalable_sizes);
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+                   OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                   OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+                   DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);

   let results = (outs);
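In transform-dialect terms, the new unit attribute is opt-in and sits alongside vector_sizes. A minimal usage sketch, mirroring the @mmt4d_scalable_with_assume test added by this commit (the %mmt4d handle name comes from that test):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Match the payload op, then vectorize it while asserting that every
    // dynamic dim equals the corresponding (scalable) vector size.
    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
    transform.yield
  }
}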

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Lines changed: 2 additions & 1 deletion

@@ -880,7 +880,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeDynamicDimsMatchVecSizes = false);

 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Lines changed: 2 additions & 1 deletion

@@ -3920,7 +3920,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false));
+                        getVectorizeNdExtract().value_or(false), false,
+                        getAssumeDynamicDimsMatchVecSizes().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Lines changed: 41 additions & 12 deletions

@@ -219,7 +219,8 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeDynamicDimsMatchVecSizes = false);

   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

@@ -328,6 +329,14 @@ struct VectorizationState {
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all dynamic dims match the corresponding vector sizes?
+  ///
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };

 LogicalResult

@@ -364,10 +373,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);

@@ -467,6 +478,23 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }

+  if (assumeDynamicDimsMatchVecSizes) {
+    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
+    // vector sizes match, we still need to check the _static_ dim sizes. Only
+    // then we can be 100% sure that masking is not required.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? true
+                                  : std::get<0>(it) == std::get<1>(it);
+                     })) {
+      LDBG("Dynamic + static dimensions match vector sizes, masking is not "
+           "required.\n");
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));

@@ -2469,7 +2497,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }

 LogicalResult mlir::linalg::vectorizeOpPrecondition(

@@ -2525,11 +2554,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                   tensor::InsertSliceOp>(op);
 }

-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));

@@ -2549,7 +2577,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
     if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeDynamicDimsMatchVecSizes))) {
       LDBG("Vectorization state couldn't be initialized\n");
       return failure();
     }
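The net effect of the new flag on the emitted IR, condensed from the two scalable tests added below (transfer indices and padding values elided with "..."): without the assumption, the read of the dynamically shaped operand %B is wrapped in vector.mask; with it, the mask is skipped entirely.

// Without the assumption (see @mmt4d_scalable below): masked transfer.
%mask  = vector.create_mask %c16, %c16, %dim, %c1 : vector<16x16x[4]x1xi1>
%vec_b = vector.mask %mask { vector.transfer_read %B... : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>

// With {assume_dynamic_dims_match_vec_sizes} (see @mmt4d_scalable_with_assume
// below): the same transfer, unmasked.
%vec_b = vector.transfer_read %B... : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>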

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
Lines changed: 93 additions & 24 deletions

@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }

+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT: mask
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------

@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }

-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL: func.func @mmt4d(
-// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}

 // -----