
Commit 3b11aaa

[mlir][linalg] Add support for scalable vectorization of linalg.mmt4d (llvm#146531)
This patch adds support for scalable vectorization of linalg.mmt4d.

The key design change is the introduction of a new vectorizer state variable:

* `assumeDynamicDimsMatchVecSizes`

...along with the corresponding Transform dialect attribute:

* `assume_dynamic_dims_match_vec_sizes`

This flag instructs the vectorizer to assume that dynamic memref/tensor dimensions match the corresponding vector sizes (fixed or scalable). With this assumption, masking becomes unnecessary, which simplifies the lowering pipeline significantly.

While this assumption is not universally valid, it typically holds for `linalg.mmt4d`. Inputs and outputs are explicitly packed using `linalg.pack`, and this packing includes padding, ensuring that dimension sizes align with vector sizes (*).

* Related discussion: llvm#143920

An upcoming patch will include an end-to-end test that leverages scalable vectorization of linalg.mmt4d to demonstrate the newly enabled functionality. This would not be feasible without the changes introduced here, as it would otherwise require additional logic to handle complex - but ultimately redundant - masks.

(*) This holds provided that the tile sizes used for packing match the vector sizes used during vectorization. It is the user's responsibility to enforce this.
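For illustration, this is how the new attribute is exercised from the Transform dialect. This is a minimal sketch mirroring the test added in this commit; the sequence and handle names are only an example:

// Vectorize linalg.mmt4d with a scalable inner dimension ([4]) and tell the
// vectorizer to trust that dynamic dims equal the requested vector sizes,
// so no masks are generated.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1]
      {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
    transform.yield
  }
}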
1 parent e73d1a5 commit 3b11aaa

File tree: 5 files changed, +143 -44 lines

- mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
- mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 5 additions & 6 deletions
@@ -2431,12 +2431,11 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                       $static_vector_sizes,
-                   OptionalAttr<UnitAttr>:$vectorize_nd_extract,
-                   DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
-                       $scalable_sizes);
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
+                   OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                   OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+                   DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
 
   let results = (outs);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
@@ -880,7 +880,8 @@ FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool assumeDynamicDimsMatchVecSizes = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 2 additions & 1 deletion
@@ -3920,7 +3920,8 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false));
+                        getVectorizeNdExtract().value_or(false), false,
+                        getAssumeDynamicDimsMatchVecSizes().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 41 additions & 12 deletions
@@ -219,7 +219,8 @@ struct VectorizationState {
   /// canonical vector shape for vectorization.
   LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<int64_t> inputVectorSizes,
-                          ArrayRef<bool> inputScalableVecDims);
+                          ArrayRef<bool> inputScalableVecDims,
+                          bool assumeDynamicDimsMatchVecSizes = false);
 
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
@@ -328,6 +329,14 @@ struct VectorizationState {
   /// Global vectorization guard for the incoming rewriter. It's initialized
   /// when the vectorization state is initialized.
   OpBuilder::InsertionGuard rewriterGuard;
+
+  /// Do all dynamic dims match the corresponding vector sizes?
+  ///
+  /// When a dynamic tensor/memref dimension matches the corresponding vector
+  /// dimension, masking can be safely skipped, despite the presence of dynamic
+  /// shapes. Use this flag with care and only for cases where you are
+  /// confident the assumption holds.
+  bool assumeDynamicDimsMatchVecSizes = false;
 };
 
 LogicalResult
@@ -364,10 +373,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
 /// Initializes the vectorization state, including the computation of the
 /// canonical vector shape for vectorization.
 // TODO: Move this to the constructor when we can remove the failure cases.
-LogicalResult
-VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
-                              ArrayRef<int64_t> inputVectorSizes,
-                              ArrayRef<bool> inputScalableVecDims) {
+LogicalResult VectorizationState::initState(RewriterBase &rewriter,
+                                            LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            ArrayRef<bool> inputScalableVecDims,
+                                            bool assumeDimsMatchVec) {
+  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
   // Initialize the insertion point.
   rewriter.setInsertionPoint(linalgOp);
 
@@ -467,6 +478,23 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
+  if (assumeDynamicDimsMatchVecSizes) {
+    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
+    // vector sizes match, we still need to check the _static_ dim sizes. Only
+    // then we can be 100% sure that masking is not required.
+    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
+                     [](auto it) {
+                       return std::get<0>(it) == ShapedType::kDynamic
+                                  ? true
+                                  : std::get<0>(it) == std::get<1>(it);
+                     })) {
+      LDBG("Dynamic + static dimensions match vector sizes, masking is not "
+           "required.\n");
+      activeMaskCache[maskingMap] = Value();
+      return Value();
+    }
+  }
+
   // Permute the iteration space value sizes to compute the mask upper bounds.
   SmallVector<Value> upperBounds =
       applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
@@ -2469,7 +2497,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2525,11 +2554,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                   tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2549,7 +2577,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   VectorizationState state(rewriter);
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
-                               inputScalableVecDims))) {
+                               inputScalableVecDims,
+                               assumeDynamicDimsMatchVecSizes))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 93 additions & 24 deletions
@@ -840,6 +840,99 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.mmt4d
+///----------------------------------------------------------------------------------------
+
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK: %[[VAL_0:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 16 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_2:.*]] = memref.dim %[[B]], %[[C2]] : memref<16x16x?x1xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_1:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x[4]x1xi1>
+// CHECK: %[[VEC_B:.*]] = vector.mask %[[MASK_1]] { vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> } : vector<16x16x[4]x1xi1> -> vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_2:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[C8]], %[[DIM_2]] : vector<16x16x8x[4]xi1>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[MASK_2]] { vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> } : vector<16x16x8x[4]xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_16:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[MASK_3:.*]] = vector.create_mask %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[C8]], %[[DIM_2]], %[[VAL_6]] : vector<16x16x16x8x[4]x1xi1>
+// CHECK: %[[VAL_18:.*]] = vector.mask %[[MASK_3]] { vector.multi_reduction <add>, %[[VAL_16]], %[[VAL_15]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> } : vector<16x16x16x8x[4]x1xi1> -> vector<16x16x8x[4]xf32>
+// CHECK: vector.mask %[[MASK_2]] { vector.transfer_write %[[VAL_18]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> } : vector<16x16x8x[4]xi1>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x16x?x1xf32>, %C_in: memref<16x16x8x?xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x?x1xf32>)
+               outs(%C_in: memref<16x16x8x?xf32>)
+  return
+}
+// CHECK-LABEL: func.func @mmt4d_scalable_with_assume(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>,
+// CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>,
+// CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) {
+// CHECK-NOT: mask
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_13:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32>
+// CHECK: %[[VAL_14:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32>
+// CHECK: %[[VAL_15:.*]] = vector.multi_reduction <add>, %[[VAL_14]], %[[VAL_13]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32>
+// CHECK: vector.transfer_write %[[VAL_15]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d vector_sizes [16, 16, 16, 8, [4], 1] {assume_dynamic_dims_match_vec_sizes} : !transform.any_op
+    transform.yield
+  }
+}
+
 ///----------------------------------------------------------------------------------------
 /// Tests for other Ops
 ///----------------------------------------------------------------------------------------
@@ -1094,30 +1187,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
-
-func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
-  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
-               outs(%C_in: memref<16x16x8x8xf32>)
-  return
-}
-
-// CHECK-LABEL: func.func @mmt4d(
-// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
-// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
-// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
-// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
-// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %mmt4d : !transform.any_op
-    transform.yield
-  }
-}
 
 // -----
