|
24 | 24 | #include "mlir/IR/TypeUtilities.h" |
25 | 25 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" |
26 | 26 | #include "mlir/Interfaces/LoopLikeInterface.h" |
| 27 | +#include "mlir/Support/LLVM.h" |
27 | 28 | #include "llvm/ADT/DenseSet.h" |
28 | 29 | #include "llvm/ADT/STLExtras.h" |
29 | 30 | #include "llvm/ADT/SmallBitVector.h" |
@@ -1968,14 +1969,94 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> { |
1968 | 1969 | return success(); |
1969 | 1970 | } |
1970 | 1971 | }; |
| 1972 | + |
| 1973 | +/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by |
| 1974 | +/// matching constant output_shape operands of the expand. This makes the |
| 1975 | +/// `tensor.expand_shape` more static and creates a consumer cast that can be |
| 1976 | +/// propagated further. |
| 1977 | +struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> { |
| 1978 | + using OpRewritePattern<ExpandShapeOp>::OpRewritePattern; |
| 1979 | + |
| 1980 | + LogicalResult matchAndRewrite(ExpandShapeOp expandOp, |
| 1981 | + PatternRewriter &rewriter) const override { |
| 1982 | + auto castOp = expandOp.getSrc().getDefiningOp<CastOp>(); |
| 1983 | + if (!canFoldIntoConsumerOp(castOp)) |
| 1984 | + return failure(); |
| 1985 | + |
| 1986 | + ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape(); |
| 1987 | + SmallVector<ReassociationIndices, 4> reassoc = |
| 1988 | + expandOp.getReassociationIndices(); |
| 1989 | + |
| 1990 | + SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape()); |
| 1991 | + SmallVector<Value> dynamicOutputShape; |
| 1992 | + auto outputIt = expandOp.getOutputShape().begin(); |
| 1993 | + |
| 1994 | + for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) { |
| 1995 | + for (uint64_t outDim : innerReassoc) { |
| 1996 | + if (!ShapedType::isDynamic(newOutputShape[outDim])) |
| 1997 | + continue; |
| 1998 | + |
| 1999 | + // If the cast's src type is dynamic, don't infer any of the |
| 2000 | + // corresponding expanded dimensions. `tensor.expand_shape` requires at |
| 2001 | + // least one of the expanded dimensions to be dynamic if the input is |
| 2002 | + // dynamic. |
| 2003 | + Value val = *outputIt; |
| 2004 | + ++outputIt; |
| 2005 | + if (ShapedType::isDynamic(castSrcShape[inputDim])) { |
| 2006 | + dynamicOutputShape.push_back(val); |
| 2007 | + continue; |
| 2008 | + } |
| 2009 | + |
| 2010 | + APInt cst; |
| 2011 | + if (matchPattern(val, m_ConstantInt(&cst))) { |
| 2012 | + newOutputShape[outDim] = cst.getSExtValue(); |
| 2013 | + } else { |
| 2014 | + dynamicOutputShape.push_back(val); |
| 2015 | + } |
| 2016 | + } |
| 2017 | + } |
| 2018 | + |
| 2019 | + // Couldn't match any values, nothing to change |
| 2020 | + if (expandOp.getOutputShape().size() == dynamicOutputShape.size()) |
| 2021 | + return failure(); |
| 2022 | + |
| 2023 | + // Calculate the input shape from the output |
| 2024 | + SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l); |
| 2025 | + for (auto inDim : llvm::seq<int>(0, newInputShape.size())) { |
| 2026 | + for (auto outDim : reassoc[inDim]) { |
| 2027 | + auto ofr = newOutputShape[outDim]; |
| 2028 | + if (ShapedType::isDynamic(ofr)) { |
| 2029 | + newInputShape[inDim] = ShapedType::kDynamic; |
| 2030 | + break; |
| 2031 | + } |
| 2032 | + newInputShape[inDim] *= ofr; |
| 2033 | + } |
| 2034 | + } |
| 2035 | + |
| 2036 | + SmallVector<OpFoldResult> outputOfr = |
| 2037 | + getMixedValues(newOutputShape, dynamicOutputShape, rewriter); |
| 2038 | + auto inputType = RankedTensorType::get( |
| 2039 | + newInputShape, expandOp.getSrcType().getElementType()); |
| 2040 | + auto outputType = RankedTensorType::get( |
| 2041 | + newOutputShape, expandOp.getSrcType().getElementType()); |
| 2042 | + auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType, |
| 2043 | + expandOp.getSrc()); |
| 2044 | + auto newExpand = rewriter.create<ExpandShapeOp>( |
| 2045 | + expandOp.getLoc(), outputType, inputCast.getResult(), |
| 2046 | + expandOp.getReassociationIndices(), outputOfr); |
| 2047 | + rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(), |
| 2048 | + newExpand.getResult()); |
| 2049 | + return success(); |
| 2050 | + } |
| 2051 | +}; |
1971 | 2052 | } // namespace |
1972 | 2053 |
|
1973 | 2054 | void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1974 | 2055 | MLIRContext *context) { |
1975 | 2056 | results.add< |
1976 | 2057 | ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>, |
1977 | 2058 | ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>, |
1978 | | - FoldReshapeWithConstant<ExpandShapeOp>, |
| 2059 | + ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>, |
1979 | 2060 | FoldReshapeWithSplat<ExpandShapeOp>, |
1980 | 2061 | FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape, |
1981 | 2062 | FoldDimOfCollapseShape>(context); |
|
0 commit comments