
Commit d6463a3

IanWood1 authored and ayounes-synaptics committed
[mlir] Convert expand_shape to more static form (llvm#112265)
Add a pattern that converts a `tensor.expand_shape` op to a more static form. The pattern matches a `tensor.cast` -> `tensor.expand_shape` sequence where the `tensor.cast` can be folded into its consumer and some of the `tensor.expand_shape`'s `output_shape` operands are constant-foldable. This makes the `tensor.expand_shape` more static and allows the static shape information to be propagated further down the program.
1 parent ea584bf commit d6463a3
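
As a rough illustration of the rewrite (a sketch only: the function name, shapes, and constants below are invented and are not taken from the patch or its tests), IR like the following would be expected to simplify under `mlir-opt -canonicalize` once the pattern is registered. The actual coverage added by the patch is in the canonicalize.mlir tests further down.

// Before: the producer tensor.cast hides the static sizes from the expand.
func.func @example(%arg0 : tensor<4x8xf32>) -> tensor<2x2x8xf32> {
  %c2 = arith.constant 2 : index
  %0 = tensor.cast %arg0 : tensor<4x8xf32> to tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c2, %c2, 8]
      : tensor<?x?xf32> into tensor<?x?x8xf32>
  %2 = tensor.cast %1 : tensor<?x?x8xf32> to tensor<2x2x8xf32>
  return %2 : tensor<2x2x8xf32>
}

// After canonicalization (expected): the expand_shape becomes fully static
// and the surrounding casts fold away.
func.func @example(%arg0 : tensor<4x8xf32>) -> tensor<2x2x8xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 2, 8]
      : tensor<4x8xf32> into tensor<2x2x8xf32>
  return %0 : tensor<2x2x8xf32>
}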

2 files changed: +136 −1 lines changed


mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 82 additions & 1 deletion
@@ -24,6 +24,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1968,14 +1969,94 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+    SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
+      return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
+      for (auto outDim : reassoc[inDim]) {
+        auto ofr = newOutputShape[outDim];
+        if (ShapedType::isDynamic(ofr)) {
+          newInputShape[inDim] = ShapedType::kDynamic;
+          break;
+        }
+        newInputShape[inDim] *= ofr;
+      }
+    }
+
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
+    auto inputType = RankedTensorType::get(
+        newInputShape, expandOp.getSrcType().getElementType());
+    auto outputType = RankedTensorType::get(
+        newOutputShape, expandOp.getSrcType().getElementType());
+    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
+                                             expandOp.getSrc());
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        expandOp.getLoc(), outputType, inputCast.getResult(),
+        expandOp.getReassociationIndices(), outputOfr);
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 54 additions & 0 deletions
@@ -2606,3 +2606,57 @@ func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tenso
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+    -> tensor<10x1x10xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
+  return %2 : tensor<10x1x10xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast
+// CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @sink_expand_of_cast
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
+    -> tensor<?x?x?xf32> {
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @partial_sink_expand_of_cast
+// CHECK: %[[CAST:.+]] = tensor.cast
+// CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[RES]]
