@@ -354,6 +354,35 @@ bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
354354 castOp.getType ());
355355}
356356
357+ bool mlir::tensor::hasFoldableTensorCastOperand (Operation *op) {
358+ return llvm::any_of (op->getOpOperands (), [&](OpOperand &opOperand) {
359+ if (llvm::isa<BlockArgument>(opOperand.get ()))
360+ return false ;
361+ auto castOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
362+ return castOp && canFoldIntoConsumerOp (castOp);
363+ });
364+ }
365+
366+ SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding (
367+ DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
368+ SmallVector<Value> newOperands;
369+ newOperands.reserve (op->getNumOperands ());
370+
371+ assert (hasFoldableTensorCastOperand (op) && " No foldable CastOp operands!" );
372+
373+ // Assumes that the result has dpsInits followed by nonDpsInits.
374+ int64_t dpsInitIdx = 0 ;
375+ for (OpOperand &opOperand : op->getOpOperands ()) {
376+ auto tensorCastOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
377+ bool fold = canFoldIntoConsumerOp (tensorCastOp);
378+ newOperands.push_back (fold ? tensorCastOp.getOperand () : opOperand.get ());
379+ if (op.isDpsInit (&opOperand) &&
380+ !llvm::isa<MemRefType>(newOperands.back ().getType ()))
381+ newResTy[dpsInitIdx++] = newOperands.back ().getType ();
382+ }
383+ return newOperands;
384+ }
385+
357386// / Performs folding of any operand of `op` if it comes from a tensor::CastOp
358387// / that can be folded.
359388LogicalResult mlir::tensor::foldTensorCast (Operation *op) {
@@ -4777,34 +4806,7 @@ bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
47774806 isa<LoopLikeOpInterface>(op.getOperation ()))
47784807 return false ;
47794808
4780- // If no operand comes from a tensor::CastOp and can be folded then fail.
4781- bool hasTensorCastOperand =
4782- llvm::any_of (op->getOpOperands (), [&](OpOperand &opOperand) {
4783- if (llvm::isa<BlockArgument>(opOperand.get ()))
4784- return false ;
4785- auto castOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
4786- return castOp && canFoldIntoConsumerOp (castOp);
4787- });
4788-
4789- return hasTensorCastOperand;
4790- }
4791-
4792- static SmallVector<Value> getNewOperands (DestinationStyleOpInterface op,
4793- SmallVector<Type> &newResTy) {
4794- SmallVector<Value> newOperands;
4795- newOperands.reserve (op->getNumOperands ());
4796-
4797- // Assumes that the result has dpsInits followed by nonDpsInits.
4798- int64_t dpsInitIdx = 0 ;
4799- for (OpOperand &opOperand : op->getOpOperands ()) {
4800- auto tensorCastOp = opOperand.get ().getDefiningOp <tensor::CastOp>();
4801- bool fold = canFoldIntoConsumerOp (tensorCastOp);
4802- newOperands.push_back (fold ? tensorCastOp.getOperand () : opOperand.get ());
4803- if (op.isDpsInit (&opOperand) &&
4804- !llvm::isa<MemRefType>(newOperands.back ().getType ()))
4805- newResTy[dpsInitIdx++] = newOperands.back ().getType ();
4806- }
4807- return newOperands;
4809+ return hasFoldableTensorCastOperand (op);
48084810}
48094811
48104812// Given the (potentially) updated packed type, `newPackedTy`, generates an
@@ -4868,7 +4870,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
48684870 return failure ();
48694871
48704872 SmallVector<Type> newResultTypes (op->getResultTypes ());
4871- SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4873+ SmallVector<Value> newOperands =
4874+ getUpdatedOperandsAfterCastOpFolding (op, newResultTypes);
48724875
48734876 // Get the updated mixed-tile-sizes attribute.
48744877 SmallVector<OpFoldResult> newMixedTileSizes =
@@ -4920,7 +4923,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
49204923 return failure ();
49214924
49224925 SmallVector<Type> newResultTypes (op->getResultTypes ());
4923- SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4926+ SmallVector<Value> newOperands =
4927+ getUpdatedOperandsAfterCastOpFolding (op, newResultTypes);
49244928 Value sourceTensor = newOperands[0 ];
49254929
49264930 // Get the updated mixed-tile-sizes attribute.
@@ -4980,7 +4984,8 @@ struct FoldTensorCastProducerOp
49804984 return failure ();
49814985
49824986 SmallVector<Type> newResultTypes (op->getResultTypes ());
4983- SmallVector<Value> newOperands = getNewOperands (op, newResultTypes);
4987+ SmallVector<Value> newOperands =
4988+ getUpdatedOperandsAfterCastOpFolding (op, newResultTypes);
49844989
49854990 // Clone op
49864991 auto newOp = clone (rewriter, op, newResultTypes, newOperands);
0 commit comments