@@ -1605,63 +1605,49 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
-///   %res = vector.transfer_write %vectorToStore into %dest
+///   %res = vector.transfer_write %vecToStore into %dest
 ///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+///   %mask = vector.create_mask(%destShape) : %vecToStoreShape
 ///   %res = vector.mask %mask {
-///     vector.transfer_write %vectorToStore into %dest
+///     vector.transfer_write %vecToStore into %dest
 ///   }
 ///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
 /// i1), and the mask values are based on the shape of the `dest` tensor.
 ///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
-///   %write = vector.transfer_write %vectorToStore into %dest
+///   %write = vector.transfer_write %vecToStore into %dest
 ///   in_bounds_flags = (...)
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
 static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
-                         Value dest,
-                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
-                         SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+                         Value dest, SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
   int64_t destRank = destType.getRank();
   auto destShape = destType.getShape();
 
-  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
   int64_t vecToStoreRank = vecToStoreType.getRank();
   auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
   SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
-    // In this case, assume that all the required vector sizes have been
-    // provided.
-    assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(vecToStoreType.getRank()) &&
-           "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
     // FIXME: This computation is too weak - it ignores the write indices.
     for (unsigned i = 0; i < vecToStoreRank; i++)
       inBoundsVal[i] =
-          (destShape[i] >= inputVecSizesForLeadingDims[i]) &&
+          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
           !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
   }
 
@@ -1677,7 +1663,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   // Generate the xfer_write Op
   Operation *write =
       builder.create<vector::TransferWriteOp>(loc,
-                                               /*vector=*/vectorToStore,
+                                               /*vector=*/vecToStore,
                                                /*source=*/dest,
                                                /*indices=*/writeIndices,
                                                /*inBounds=*/inBoundsVal);
@@ -1686,46 +1672,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
   if (useInBoundsInsteadOfMasking)
     return write;
 
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
-  // Check if masking is needed.
-  bool needMaskForWrite =
-      !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(destRank - vecToStoreRank +
-                                        inputVecSizesForLeadingDims.size()));
-
-  // If masking is needed, generate the mask and mask the operation.
-  if (needMaskForWrite) {
-    // Get the mask shape + type. Missing mask dimensions are taken from
-    // `vectorToStore`.
-    SmallVector<int64_t> writeMaskShape;
-    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
-                          inputVecSizesForLeadingDims.end());
-    if (vecToStoreRank >
-        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
-      writeMaskShape.append(vecToStoreShape.begin() +
-                                inputVecSizesForLeadingDims.size(),
-                            vecToStoreShape.end());
-    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
-    SmallVector<OpFoldResult> destSizes =
-        tensor::getMixedSizes(builder, loc, dest);
-    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
-                                        destSizes.end());
-
-    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
-                                writeMaskShape))
-      return write;
-
-    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
-        loc, writeMaskType, maskSizes);
-    write = mlir::vector::maskOperation(builder, write, maskForWrite);
-  }
+  // Check if masking is needed. If not, exit.
+  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+    return write;
+
+  // Compute the mask and mask the write Op.
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+  SmallVector<OpFoldResult> destSizes =
+      tensor::getMixedSizes(builder, loc, dest);
+  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+                                      destSizes.end());
+
+  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                              vecToStoreShape))
+    return write;
 
-  return write;
+  Value maskForWrite =
+      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+  return mlir::vector::maskOperation(builder, write, maskForWrite);
 }
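The simplified masking path reduces to one comparison: a mask is needed iff the trailing `rank(vecToStore)` destination dims differ from the vector's shape, and the mask sizes are then exactly those trailing destination sizes. A minimal sketch of that logic, using `std::vector`/`std::equal` in place of MLIR's `ArrayRef`/`llvm::equal` and hypothetical example shapes:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> destShape = {4, 10, 15};   // destination tensor shape
  std::vector<int64_t> vecToStoreShape = {16, 16}; // vector-to-store shape
  size_t vecRank = vecToStoreShape.size();

  // Equivalent of llvm::equal(vecToStoreShape, destShape.take_back(vecRank)):
  // compare the vector shape against the trailing destination dims.
  bool needsMask = !std::equal(vecToStoreShape.begin(), vecToStoreShape.end(),
                               destShape.end() - vecRank);
  std::cout << std::boolalpha << "needs mask: " << needsMask << "\n";

  if (needsMask) {
    // Mask sizes: the last rank(vecToStore) destination sizes, here {10, 15}.
    std::vector<int64_t> maskSizes(destShape.end() - vecRank, destShape.end());
    for (size_t i = 0; i < maskSizes.size(); ++i)
      std::cout << "mask size[" << i << "] = " << maskSizes[i] << "\n";
  }
}
```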
 
 /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1825,9 +1790,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, transposeOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1965,7 +1929,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
       shapeCastOp.getResult().getType().getElementType());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), dest,
-      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
@@ -1998,9 +1961,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, maskedRead, dest,
-      /*inputVecSizesForLeadingDims=*/inputVectorSizes);
+  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -3040,9 +3001,9 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   // Create write
   auto writeIndices =
       getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
-  Operation *write = createWriteOrMaskedWrite(
-      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
+                               writeIndices, inputVectorSizes.empty());
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
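With `inputVecSizesForLeadingDims` gone, every call site passes just the value and the destination, plus write indices and the `in_bounds` toggle only when needed. A plain-C++ analog of the defaulted-parameter ergonomics; the names and types below are illustrative stand-ins, not MLIR APIs:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Analog of the slimmed-down createWriteOrMaskedWrite signature: the vector
// sizes come from the value itself, so the trailing parameters are optional.
static void writeOrMaskedWrite(const std::vector<int64_t> &vecShape,
                               const std::vector<int64_t> &destShape,
                               std::vector<int64_t> writeIndices = {},
                               bool useInBoundsInsteadOfMasking = false) {
  // Default behaviour from the doc comment: empty indices mean all offsets 0.
  if (writeIndices.empty())
    writeIndices.assign(destShape.size(), 0);
  std::cout << "write of rank-" << vecShape.size() << " vector at offset "
            << writeIndices.front()
            << (useInBoundsInsteadOfMasking ? " (in_bounds)" : " (maskable)")
            << "\n";
}

int main() {
  // Pack/pad-style call sites: value + destination only.
  writeOrMaskedWrite({16, 4}, {16, 4});
  // Insert-slice-style call site: explicit indices plus the in_bounds toggle.
  writeOrMaskedWrite({8}, {4, 16}, {0, 8},
                     /*useInBoundsInsteadOfMasking=*/true);
}
```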